Skip to content

Commit

Permalink
Add the models and process of conv_in_time in src/lava/proc/conv_in_t…
Browse files Browse the repository at this point in the history
…ime (#833)

* add the models and process of conv_in_time in src/lava/proc/conv_in_time

* remove unused library

* remove Trailing whitespace

* add unittest for conv in time and related pytorch ground truth

* add fixed_pt version of conv in time

* change input to spike_input

* add from lava.proc.conv import utils

* remove unwanted comments

* fixed some linting errors

* Start all comments with upper case character & change the year for all copyright headers to 2024 for new files

* remove whitespace

* continuation line under-indented

* Trailing whitespace

* shorten variables names

---------

Co-authored-by: bamsumit <[email protected]>
Co-authored-by: PhilippPlank <[email protected]>
  • Loading branch information
3 people authored Apr 25, 2024
1 parent fae3ea1 commit 2f3e0f8
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 0 deletions.
99 changes: 99 additions & 0 deletions src/lava/proc/conv_in_time/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np

from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.proc.conv_in_time.process import ConvInTime
from lava.proc.conv import utils


class AbstractPyConvInTimeModel(PyLoihiProcessModel):
"""Abstract Conn In Time Process with Dense synaptic connections
which incorporates delays into the Conv Process.
"""
weights: np.ndarray = None
a_buff: np.ndarray = None
kernel_size: int = None

num_message_bits: np.ndarray = LavaPyType(np.ndarray, np.int8, precision=5)

def calc_act(self, s_in) -> np.ndarray:
"""
Calculate the activation buff by inverse the order in
the kernel. Taking k=3 as an example, the a_buff will be
weights[2] * s_in, weights[1] * s_in, weights[0] * s_in
"""

# The change of the shape is shown below:
# sum([K, n_out, n_in] * [n_in,], axis=-1) = [K, n_out] -> [n_out, K]
kernel_size = self.weights.shape[0]
for i in range(kernel_size):
self.a_buff[:, i] += np.sum(
self.weights[kernel_size - i - 1] * s_in, axis=-1).T

def update_act(self, s_in):
"""
Updates the activations for the connection.
Clears first column of a_buff and rolls them to the last column.
Finally, calculates the activations for the current time step and adds
them to a_buff.
This order of operations ensures that delays of 0 correspond to
the next time step.
"""
self.a_buff[:, 0] = 0
self.a_buff = np.roll(self.a_buff, -1)
self.calc_act(s_in)

def run_spk(self):
# The a_out sent on a each timestep is a buffered value from dendritic
# accumulation at timestep t-1. This prevents deadlocking in
# networks with recurrent connectivity structures.
self.a_out.send(self.a_buff[:, 0])
if self.num_message_bits.item() > 0:
s_in = self.s_in.recv()
else:
s_in = self.s_in.recv().astype(bool)
self.update_act(s_in)


@implements(proc=ConvInTime, protocol=LoihiProtocol)
@requires(CPU)
@tag("floating_pt")
class PyConvInTimeFloat(AbstractPyConvInTimeModel):
"""Implementation of Conn In Time Process with Dense synaptic connections in
floating point precision. This short and simple ProcessModel can be used
for quick algorithmic prototyping, without engaging with the nuances of a
fixed point implementation. DelayDense incorporates delays into the Conn
Process.
"""
s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
a_buff: np.ndarray = LavaPyType(np.ndarray, float)
# The weights is a 3D matrix of form (kernel_size,
# num_flat_output_neurons, num_flat_input_neurons) in C-order (row major).
weights: np.ndarray = LavaPyType(np.ndarray, float)
num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)


@implements(proc=ConvInTime, protocol=LoihiProtocol)
@requires(CPU)
@tag("fixed_pt")
class PyConvInTimeFixed(AbstractPyConvInTimeModel):
"""Conv In Time with fixed point synapse implementation."""
s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
a_buff: np.ndarray = LavaPyType(np.ndarray, float)
# The weights is a 3D matrix of form (kernel_size,
# num_flat_output_neurons, num_flat_input_neurons) in C-order (row major).
weights: np.ndarray = LavaPyType(np.ndarray, float)
num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)

def clamp_precision(self, x: np.ndarray) -> np.ndarray:
return utils.signed_clamp(x, bits=24)
86 changes: 86 additions & 0 deletions src/lava/proc/conv_in_time/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np
import typing as ty

from lava.magma.core.process.process import AbstractProcess, LogConfig
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import InPort, OutPort


class ConvInTime(AbstractProcess):
"""Connection Process that mimics a convolution of the incoming
events/spikes with a kernel in the time dimension. Realizes the
following abstract behavior: a_out[t] = weights[t-1] * s_in[t-1]
+ weights[t] * s_in[t] + weights[t+1] * s_in[t+1]
Parameters
----------
weights : numpy.ndarray
3D connection weight matrix of form (kernel_size,
num_flat_output_neurons, num_flat_input_neurons)
in C-order (row major).
weight_exp : int, optional
Shared weight exponent of base 2 used to scale magnitude of
weights, if needed. Mostly for fixed point implementations.
Unnecessary for floating point implementations.
Default value is 0.
num_weight_bits : int, optional
Shared weight width/precision used by weight. Mostly for fixed
point implementations. Unnecessary for floating point
implementations.
Default is for weights to use full 8 bit precision.
sign_mode : SignMode, optional
Shared indicator whether synapse is of type SignMode.NULL,
SignMode.MIXED, SignMode.EXCITATORY, or SignMode.INHIBITORY. If
SignMode.MIXED, the sign of the weight is
included in the weight bits and the fixed point weight used for
inference is scaled by 2.
Unnecessary for floating point implementations.
In the fixed point implementation, weights are scaled according to
the following equations:
w_scale = 8 - num_weight_bits + weight_exp + isMixed()
weights = weights * (2 ** w_scale)
num_message_bits : int, optional
Determines whether the Dense Process deals with the incoming
spikes as binary spikes (num_message_bits = 0) or as graded
spikes (num_message_bits > 0). Default is 0.
"""
def __init__(self,
*,
weights: np.ndarray,
name: ty.Optional[str] = None,
num_message_bits: ty.Optional[int] = 0,
log_config: ty.Optional[LogConfig] = None,
**kwargs) -> None:

super().__init__(weights=weights,
num_message_bits=num_message_bits,
name=name,
log_config=log_config,
**kwargs)

self._validate_weights(weights)
# [kernel_size, n_flat_output_neurons, n_flat_input_neurons]
shape = weights.shape
# Ports
self.s_in = InPort(shape=(shape[2],))
self.a_out = OutPort(shape=(shape[1],))

# Variables
self.weights = Var(shape=shape, init=weights)
self.a_buff = Var(shape=(shape[1], shape[0]), init=0)
self.num_message_bits = Var(shape=(1,), init=num_message_bits)

@staticmethod
def _validate_weights(weights: np.ndarray) -> None:
if len(np.shape(weights)) != 3:
raise ValueError("Dense Process 'weights' expects a 3D matrix, "
f"got {weights}.")
Empty file.
3 changes: 3 additions & 0 deletions tests/lava/proc/conv_in_time/gts/q_weights.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/lava/proc/conv_in_time/gts/spike_input.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/lava/proc/conv_in_time/gts/torch_output.npy
Git LFS file not shown
80 changes: 80 additions & 0 deletions tests/lava/proc/conv_in_time/test_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import unittest
import os
import numpy as np
from lava.proc.conv_in_time.process import ConvInTime
from lava.proc import io

from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg
from lava.proc.conv import utils

if utils.TORCH_IS_AVAILABLE:
import torch
import torch.nn as nn
compare = True
# In this case, the test compares against random torch ground truth
else:
compare = False
# In this case, the test compares against saved torch ground truth


class TestConvInTimeProcess(unittest.TestCase):
"""Tests for Conv class"""
def test_init(self) -> None:
"""Tests instantiation of Conv In Time"""
num_steps = 10
kernel_size = 3
n_in = 2
n_out = 5
if compare:
spike_input = np.random.choice(
[0, 1],
size=(n_in, num_steps))
weights = np.random.randint(256, size=[kernel_size,
n_out,
n_in]) - 128
else:
spike_input = np.load(os.path.join(os.path.dirname(__file__),
"gts/spike_input.npy"))
weights = np.load(os.path.join(os.path.dirname(__file__),
"gts/q_weights.npy"))
sender = io.source.RingBuffer(data=spike_input)
conv_in_time = ConvInTime(weights=weights, name='conv_in_time')

receiver = io.sink.RingBuffer(
shape=(n_out,),
buffer=num_steps + 1)

sender.s_out.connect(conv_in_time.s_in)
conv_in_time.a_out.connect(receiver.a_in)

run_condition = RunSteps(num_steps=num_steps + 1)
run_cfg = Loihi1SimCfg(select_tag="floating_pt")

conv_in_time.run(condition=run_condition, run_cfg=run_cfg)
output = receiver.data.get()
conv_in_time.stop()

if compare:
tensor_input = torch.tensor(spike_input, dtype=torch.float32)
tensor_weights = torch.tensor(weights, dtype=torch.float32)
conv_layer = nn.Conv1d(
in_channels=n_in,
out_channels=n_out,
kernel_size=kernel_size, bias=False)
# Permute the weights to match the torch format
conv_layer.weight = nn.Parameter(tensor_weights.permute(1, 2, 0))
torch_output = conv_layer(
tensor_input.unsqueeze(0)).squeeze(0).detach().numpy()
else:
torch_output = np.load(os.path.join(os.path.dirname(__file__),
"gts/torch_output.npy"))

self.assertEqual(output.shape, (n_out, num_steps + 1))
# After kernel_size timesteps,
# the output should be the same as the torch output
assert np.allclose(output[:, kernel_size:], torch_output)

0 comments on commit 2f3e0f8

Please sign in to comment.