Skip to content

Commit

Permalink
Add pure S4D (#868)
Browse files Browse the repository at this point in the history
* adding pure S4D

* typo

* rename method

* typo

---------

Co-authored-by: PhilippPlank <[email protected]>
  • Loading branch information
smm-ncl and PhilippPlank committed Jul 16, 2024
1 parent 6a98a52 commit 11c8b02
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 17 deletions.
66 changes: 65 additions & 1 deletion src/lava/proc/s4d/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,76 @@
from lava.proc.sdn.models import AbstractSigmaDeltaModel
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer
from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer, S4d
from lava.magma.core.resources import CPU
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.sub.model import AbstractSubProcessModel
from lava.proc.sparse.process import Sparse
from lava.magma.core.model.py.model import PyLoihiProcessModel


@implements(proc=S4d, protocol=LoihiProtocol)
@requires(CPU)
@tag('floating_pt')
class S4dModel(PyLoihiProcessModel):
a_in = LavaPyType(PyInPort.VEC_DENSE, float)
s_out = LavaPyType(PyOutPort.VEC_DENSE, float)
s4_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3)

# S4 variables
s4_state: np.ndarray = LavaPyType(np.ndarray, complex)
a: np.ndarray = LavaPyType(np.ndarray, complex)
b: np.ndarray = LavaPyType(np.ndarray, complex)
c: np.ndarray = LavaPyType(np.ndarray, complex)

def __init__(self, proc_params: Dict[str, Any]) -> None:
"""
Neuron model that implements S4D
(as described by Gu et al., 2022) dynamics.
Relevant parameters in proc_params
--------------------------
a: np.ndarray
Diagonal elements of the state matrix of the discretized S4D model.
b: np.ndarray
Diagonal elements of the input matrix of the discretized S4D model.
c: np.ndarray
Diagonal elements of the output matrix of the discretized S4D model.
s4_state: np.ndarray
State vector of the S4D discretized model.
"""
super().__init__(proc_params)
self.a = self.proc_params['a']
self.b = self.proc_params['b']
self.c = self.proc_params['c']
self.s4_state = self.proc_params['s4_state']

def run_spk(self) -> None:
"""Performs S4D dynamics.
This function simulates the behavior of a linear time-invariant system
with diagonalized state-space representation.
(For reference see Gu et al., 2022)
The state-space equations are given by:
s4_state_{k+1} = A * s4_state_k + B * input_k
act_k = C * s4_state_k
where:
- s4_state_k is the state vector at time step k,
- input_k is the input vector at time step k,
- act_k is the output vector at time step k,
- A is the diagonal state matrix,
- B is the diagonal input matrix,
- C is the diagonal output matrix.
The function computes the next output step of the
system for the given input signal.
"""
inp = self.a_in.recv()
self.s4_state = (self.s4_state * self.a + inp * self.b)
self.s_out.send(np.real(self.c * self.s4_state * 2))


class AbstractSigmaS4dDeltaModel(AbstractSigmaDeltaModel):
Expand Down
70 changes: 70 additions & 0 deletions src/lava/proc/s4d/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,76 @@
from lava.proc.sdn.process import ActivationMode, SigmaDelta


class S4d(AbstractProcess):
def __init__(
self,
shape: ty.Tuple[int, ...],
a: float,
b: float,
c: float,
s4_state: ty.Optional[int] = 0,
s4_exp: ty.Optional[int] = 0) -> None:
"""
Neuron process that implements S4D (described by
Gu et al., 2022) dynamics.
This process simulates the behavior of a linear time-invariant system
with diagonal state-space representation.
The state-space equations are given by:
s4_state_{k+1} = A * s4_state_k + B * inp_k
act_k = C * s4_state_k
where:
- s4_state_k is the state vector at time step k,
- inp_k is the input vector at time step k,
- act_k is the output vector at time step k,
- A is the diagonal state matrix,
- B is the diagonal input matrix,
- C is the diagonal output matrix.
Parameters
----------
shape: Tuple
Shape of the sigma process.
vth: int or float
Threshold of the delta encoder.
a: np.ndarray
Diagonal elements of the state matrix of the S4D model.
b: np.ndarray
Diagonal elements of the input matrix of the S4D model.
c: np.ndarray
Diagonal elements of the output matrix of the S4D model.
s4_state: int or float
Initial state of the S4D model.
s4_exp: int
Scaling exponent with base 2 for the S4 state variables.
Note: This should only be used for nc models.
Default is 0.
"""

super().__init__(shape=shape,
a=a,
b=b,
c=c,
s4_state=s4_state,
s4_exp=s4_exp)
# Ports
self.a_in = InPort(shape=shape)
self.s_out = OutPort(shape=shape)

# Variables for S4
self.a = Var(shape=shape, init=a)
self.b = Var(shape=shape, init=b)
self.c = Var(shape=shape, init=c)
self.s4_state = Var(shape=shape, init=0)
self.s4_exp = Var(shape=(1,), init=s4_exp)

@property
def shape(self) -> ty.Tuple[int, ...]:
"""Return shape of the Process."""
return self.proc_params['shape']


class SigmaS4dDelta(SigmaDelta, AbstractProcess):
def __init__(
self,
Expand Down
3 changes: 3 additions & 0 deletions tests/lava/proc/s4d/dA_complex.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/lava/proc/s4d/dB_complex.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/lava/proc/s4d/dC_complex.npy
Git LFS file not shown
88 changes: 87 additions & 1 deletion tests/lava/proc/s4d/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,104 @@
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/


import unittest
import numpy as np
from typing import Tuple
import lava.proc.io as io
from lava.magma.core.run_conditions import RunSteps
from lava.proc.sdn.process import ActivationMode, SigmaDelta
from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer
from lava.proc.s4d.process import S4d, SigmaS4dDelta, SigmaS4dDeltaLayer
from lava.proc.sparse.process import Sparse
from lava.magma.core.run_configs import Loihi2SimCfg
from tests.lava.proc.s4d.utils import get_coefficients, run_original_model


class TestS4DModel(unittest.TestCase):
"""Tests for S4d neuron"""
def run_in_sim(
self,
inp: np.ndarray,
a: np.ndarray,
b: np.ndarray,
c: np.ndarray,
num_steps: int,
model_dim: int,
d_states: int,
) -> Tuple[np.ndarray]:

# Get S4D matrices
a = a[:model_dim * d_states]
b = b[:model_dim * d_states]
c = c[:model_dim * d_states]

# Setup network: input -> expansion -> S4D neuron -> output
kron_matrix = np.kron(np.eye(model_dim), np.ones((d_states, )))
spiker = io.source.RingBuffer(data=inp)
sparse_1 = Sparse(weights=kron_matrix.T, num_message_bits=24)
neuron = S4d(shape=((model_dim * d_states,)),
a=a,
b=b,
c=c)

receiver = io.sink.RingBuffer(buffer=num_steps,
shape=(model_dim * d_states,))
spiker.s_out.connect(sparse_1.s_in)
sparse_1.a_out.connect(neuron.a_in)
neuron.s_out.connect(receiver.a_in)

run_cfg = Loihi2SimCfg(select_tag="floating_pt")
neuron.run(
condition=RunSteps(num_steps=num_steps), run_cfg=run_cfg)
received_data_sim = receiver.data.get()
neuron.stop()

return received_data_sim

def compare_s4d_model_to_original_equations(self,
model_dim: int = 10,
d_states: int = 5,
n_steps: int = 5,
inp_exp: int = 5,
is_real: bool = False) -> None:

"""Asserts that the floating point lava simulation for S4d outputs
exactly the same values as the original equations."""
a, b, c = get_coefficients(is_real=is_real)
np.random.seed(0)
inp = (np.random.random((model_dim, n_steps)) * 2**inp_exp).astype(int)
out_lava = self.run_in_sim(inp=inp,
num_steps=n_steps,
a=a,
b=b,
c=c,
model_dim=model_dim,
d_states=d_states)
out_original_equations = run_original_model(inp=inp,
num_steps=n_steps,
model_dim=model_dim,
d_states=d_states,
a=a,
b=b,
c=c,
perform_reduction=False)

np.testing.assert_array_equal(out_original_equations[:, :-1],
out_lava[:, 1:])

def test_s4d_real_model_single_hidden_state(self) -> None:
self.compare_s4d_model_to_original_equations(is_real=True, d_states=1)

def test_s4d_real_model_multiple_hidden_state(self) -> None:
self.compare_s4d_model_to_original_equations(is_real=True, d_states=5)

def test_s4d_complex_model_single_hidden_state(self) -> None:
self.compare_s4d_model_to_original_equations(is_real=False, d_states=1)

def test_s4d_complex_model_multiple_hidden_state(self) -> None:
self.compare_s4d_model_to_original_equations(is_real=False, d_states=5)


class TestSigmaS4DDeltaModels(unittest.TestCase):
"""Tests for SigmaS4Delta neuron"""
def run_in_lava(
Expand Down
28 changes: 26 additions & 2 deletions tests/lava/proc/s4d/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,31 @@

import unittest
import numpy as np
from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer
from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer, S4d


class TestS4dProcess(unittest.TestCase):
"""Tests for S4d Class"""

def test_init(self) -> None:
"""Tests instantiation of S4d"""
shape = 10
s4_exp = 12
a = np.ones(shape) * 0.5
b = np.ones(shape) * 0.8
c = np.ones(shape) * 0.9
s4d = S4d(shape=(shape,),
s4_exp=s4_exp,
a=a,
b=b,
c=c)

self.assertEqual(s4d.shape, (shape,))
self.assertEqual(s4d.s4_exp.init, s4_exp)
np.testing.assert_array_equal(s4d.a.init, a)
np.testing.assert_array_equal(s4d.b.init, b)
np.testing.assert_array_equal(s4d.c.init, c)
self.assertEqual(s4d.s4_state.init, 0)


class TestSigmaS4dDeltaProcess(unittest.TestCase):
Expand Down Expand Up @@ -37,7 +61,7 @@ def test_init(self) -> None:
self.assertEqual(sigma_s4_delta.state_exp.init, state_exp)
self.assertEqual(sigma_s4_delta.s4_state.init, 0)

# default sigmadelta params - inherited from SigmaDelta class
# default sigma-delta params - inherited from SigmaDelta class
self.assertEqual(sigma_s4_delta.cum_error.init, False)
self.assertEqual(sigma_s4_delta.spike_exp.init, 0)
self.assertEqual(sigma_s4_delta.bias.init, 0)
Expand Down
Loading

0 comments on commit 11c8b02

Please sign in to comment.