Commit: Add the models and process of conv_in_time in src/lava/proc/conv_in_time (#833)

* Add the models and process of conv_in_time in src/lava/proc/conv_in_time
* Remove unused library
* Remove trailing whitespace
* Add unit test for conv in time and related pytorch ground truth
* Add fixed_pt version of conv in time
* Change input to spike_input
* Add from lava.proc.conv import utils
* Remove unwanted comments
* Fix some linting errors
* Start all comments with an upper-case character & change the year in all copyright headers to 2024 for new files
* Remove whitespace
* Fix continuation line under-indented
* Fix trailing whitespace
* Shorten variable names

Co-authored-by: bamsumit <[email protected]>
Co-authored-by: PhilippPlank <[email protected]>
1 parent fae3ea1 · commit 2f3e0f8
Showing 7 changed files with 274 additions and 0 deletions.
@@ -0,0 +1,99 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np

from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.proc.conv_in_time.process import ConvInTime
from lava.proc.conv import utils

class AbstractPyConvInTimeModel(PyLoihiProcessModel):
    """Abstract Conv In Time Process with dense synaptic connections,
    mimicking a convolution of the incoming spikes with a kernel
    along the time dimension.
    """
    weights: np.ndarray = None
    a_buff: np.ndarray = None
    kernel_size: int = None

    num_message_bits: np.ndarray = LavaPyType(np.ndarray, np.int8, precision=5)

    def calc_act(self, s_in) -> None:
        """
        Calculate the activation buffer by traversing the kernel in
        reverse order. Taking kernel_size = 3 as an example, a_buff
        accumulates weights[2] * s_in, weights[1] * s_in and
        weights[0] * s_in into its columns.
        """
        # Shape bookkeeping: weights[k] is [n_out, n_in]; weights[k] * s_in
        # summed over the last axis gives [n_out,], which is accumulated
        # into column i of a_buff ([n_out, K]).
        kernel_size = self.weights.shape[0]
        for i in range(kernel_size):
            self.a_buff[:, i] += np.sum(
                self.weights[kernel_size - i - 1] * s_in, axis=-1).T

    def update_act(self, s_in):
        """
        Updates the activations for the connection.
        Clears the first column of a_buff and rolls it to the last column.
        Finally, calculates the activations for the current time step and
        adds them to a_buff.
        This order of operations ensures that delays of 0 correspond to
        the next time step.
        """
        self.a_buff[:, 0] = 0
        # Flattened roll: since column 0 was just zeroed, only zeros wrap
        # around, so this is equivalent to shifting each row left by one.
        self.a_buff = np.roll(self.a_buff, -1)
        self.calc_act(s_in)

    def run_spk(self):
        # The a_out sent at each timestep is a buffered value from dendritic
        # accumulation at timestep t-1. This prevents deadlocking in
        # networks with recurrent connectivity structures.
        self.a_out.send(self.a_buff[:, 0])
        if self.num_message_bits.item() > 0:
            s_in = self.s_in.recv()
        else:
            s_in = self.s_in.recv().astype(bool)
        self.update_act(s_in)

@implements(proc=ConvInTime, protocol=LoihiProtocol)
@requires(CPU)
@tag("floating_pt")
class PyConvInTimeFloat(AbstractPyConvInTimeModel):
    """Implementation of the Conv In Time Process with dense synaptic
    connections in floating point precision. This short and simple
    ProcessModel can be used for quick algorithmic prototyping, without
    engaging with the nuances of a fixed point implementation.
    """
    s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
    a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
    a_buff: np.ndarray = LavaPyType(np.ndarray, float)
    # weights is a 3D matrix of form (kernel_size,
    # num_flat_output_neurons, num_flat_input_neurons) in C-order (row major).
    weights: np.ndarray = LavaPyType(np.ndarray, float)
    num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)

@implements(proc=ConvInTime, protocol=LoihiProtocol)
@requires(CPU)
@tag("fixed_pt")
class PyConvInTimeFixed(AbstractPyConvInTimeModel):
    """Conv In Time with fixed point synapse implementation."""
    s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
    a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
    a_buff: np.ndarray = LavaPyType(np.ndarray, float)
    # weights is a 3D matrix of form (kernel_size,
    # num_flat_output_neurons, num_flat_input_neurons) in C-order (row major).
    weights: np.ndarray = LavaPyType(np.ndarray, float)
    num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)

    def clamp_precision(self, x: np.ndarray) -> np.ndarray:
        return utils.signed_clamp(x, bits=24)
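
To make the rolling-buffer mechanics above concrete, here is a standalone NumPy sketch (hypothetical sizes and data, mirroring calc_act and update_act rather than quoting the commit) showing how each incoming spike vector is spread over the next kernel_size output steps:

import numpy as np

np.random.seed(0)
K, n_out, n_in = 3, 5, 2           # kernel length, flat output/input sizes
weights = np.random.randn(K, n_out, n_in)
a_buff = np.zeros((n_out, K))      # column i is sent i + 1 steps from now

def update_act(a_buff, weights, s_in):
    # Clear the column that was just sent, then shift the buffer left.
    a_buff[:, 0] = 0
    # Flattened roll; only the zeros wrap, so each row shifts left by one.
    a_buff = np.roll(a_buff, -1)
    # Accumulate the new spike vector against the time-reversed kernel.
    for i in range(K):
        a_buff[:, i] += weights[K - i - 1] @ s_in
    return a_buff

for t in range(4):                 # feed a few random binary spike vectors
    s_in = np.random.choice([0, 1], size=n_in)
    a_buff = update_act(a_buff, weights, s_in)
    print(t, a_buff[:, 0])         # what a_out would send on the next step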
@@ -0,0 +1,86 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np
import typing as ty

from lava.magma.core.process.process import AbstractProcess, LogConfig
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import InPort, OutPort

class ConvInTime(AbstractProcess):
    """Connection Process that mimics a convolution of the incoming
    events/spikes with a kernel along the time dimension. For a kernel of
    length K it realizes the following abstract behavior:
    a_out[t] = weights[0] * s_in[t - K] + weights[1] * s_in[t - K + 1]
    + ... + weights[K - 1] * s_in[t - 1]

    Parameters
    ----------
    weights : numpy.ndarray
        3D connection weight matrix of form (kernel_size,
        num_flat_output_neurons, num_flat_input_neurons)
        in C-order (row major).
    weight_exp : int, optional
        Shared weight exponent of base 2 used to scale the magnitude of
        weights, if needed. Mostly for fixed point implementations.
        Unnecessary for floating point implementations.
        Default value is 0.
    num_weight_bits : int, optional
        Shared weight width/precision used by weights. Mostly for fixed
        point implementations. Unnecessary for floating point
        implementations.
        Default is for weights to use full 8 bit precision.
    sign_mode : SignMode, optional
        Shared indicator of whether the synapse is of type SignMode.NULL,
        SignMode.MIXED, SignMode.EXCITATORY, or SignMode.INHIBITORY. If
        SignMode.MIXED, the sign of the weight is
        included in the weight bits and the fixed point weight used for
        inference is scaled by 2.
        Unnecessary for floating point implementations.
        In the fixed point implementation, weights are scaled according to
        the following equations:
        w_scale = 8 - num_weight_bits + weight_exp + isMixed()
        weights = weights * (2 ** w_scale)
    num_message_bits : int, optional
        Determines whether the ConvInTime Process deals with the incoming
        spikes as binary spikes (num_message_bits = 0) or as graded
        spikes (num_message_bits > 0). Default is 0.
    """
    def __init__(self,
                 *,
                 weights: np.ndarray,
                 name: ty.Optional[str] = None,
                 num_message_bits: ty.Optional[int] = 0,
                 log_config: ty.Optional[LogConfig] = None,
                 **kwargs) -> None:

        super().__init__(weights=weights,
                         num_message_bits=num_message_bits,
                         name=name,
                         log_config=log_config,
                         **kwargs)

        self._validate_weights(weights)
        # [kernel_size, n_flat_output_neurons, n_flat_input_neurons]
        shape = weights.shape
        # Ports
        self.s_in = InPort(shape=(shape[2],))
        self.a_out = OutPort(shape=(shape[1],))

        # Variables
        self.weights = Var(shape=shape, init=weights)
        self.a_buff = Var(shape=(shape[1], shape[0]), init=0)
        self.num_message_bits = Var(shape=(1,), init=num_message_bits)

    @staticmethod
    def _validate_weights(weights: np.ndarray) -> None:
        if len(np.shape(weights)) != 3:
            raise ValueError("ConvInTime Process 'weights' expects a 3D "
                             f"matrix, got shape {np.shape(weights)}.")
Empty file.
3 Git LFS files not shown (presumably the .npy ground-truth files loaded by the test below).
@@ -0,0 +1,80 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import unittest
import os
import numpy as np
from lava.proc.conv_in_time.process import ConvInTime
from lava.proc import io

from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg
from lava.proc.conv import utils

if utils.TORCH_IS_AVAILABLE:
    import torch
    import torch.nn as nn
    compare = True
    # In this case, the test compares against random torch ground truth
else:
    compare = False
    # In this case, the test compares against saved torch ground truth

class TestConvInTimeProcess(unittest.TestCase):
    """Tests for the ConvInTime class."""
    def test_init(self) -> None:
        """Tests instantiation and execution of Conv In Time."""
        num_steps = 10
        kernel_size = 3
        n_in = 2
        n_out = 5
        if compare:
            spike_input = np.random.choice(
                [0, 1],
                size=(n_in, num_steps))
            weights = np.random.randint(256, size=[kernel_size,
                                                   n_out,
                                                   n_in]) - 128
        else:
            spike_input = np.load(os.path.join(os.path.dirname(__file__),
                                               "gts/spike_input.npy"))
            weights = np.load(os.path.join(os.path.dirname(__file__),
                                           "gts/q_weights.npy"))
        sender = io.source.RingBuffer(data=spike_input)
        conv_in_time = ConvInTime(weights=weights, name='conv_in_time')

        receiver = io.sink.RingBuffer(
            shape=(n_out,),
            buffer=num_steps + 1)

        sender.s_out.connect(conv_in_time.s_in)
        conv_in_time.a_out.connect(receiver.a_in)

        run_condition = RunSteps(num_steps=num_steps + 1)
        run_cfg = Loihi1SimCfg(select_tag="floating_pt")

        conv_in_time.run(condition=run_condition, run_cfg=run_cfg)
        output = receiver.data.get()
        conv_in_time.stop()

        if compare:
            tensor_input = torch.tensor(spike_input, dtype=torch.float32)
            tensor_weights = torch.tensor(weights, dtype=torch.float32)
            conv_layer = nn.Conv1d(
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=kernel_size, bias=False)
            # Permute the weights to match the torch format
            # (out_channels, in_channels, kernel_size).
            conv_layer.weight = nn.Parameter(tensor_weights.permute(1, 2, 0))
            torch_output = conv_layer(
                tensor_input.unsqueeze(0)).squeeze(0).detach().numpy()
        else:
            torch_output = np.load(os.path.join(os.path.dirname(__file__),
                                                "gts/torch_output.npy"))

        self.assertEqual(output.shape, (n_out, num_steps + 1))
        # After kernel_size timesteps,
        # the output should be the same as the torch output
        assert np.allclose(output[:, kernel_size:], torch_output)
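
For reference, the torch branch of this test computes a valid cross-correlation over time (which is what nn.Conv1d performs); the following NumPy-only sketch (hypothetical data, not part of the commit) reproduces that reference computation for readers without torch installed:

import numpy as np

K, n_out, n_in, T = 3, 5, 2, 10
rng = np.random.default_rng(42)
weights = rng.integers(-128, 128, size=(K, n_out, n_in))
spikes = rng.integers(0, 2, size=(n_in, T))

# ref[:, t] = sum_k weights[k] @ spikes[:, t + k], for t = 0 .. T - K
ref = np.stack([sum(weights[k] @ spikes[:, t + k] for k in range(K))
                for t in range(T - K + 1)], axis=1)
print(ref.shape)  # (n_out, T - K + 1)

Its columns line up with output[:, kernel_size:] because the process model emits buffered activations, which is why the final assertion skips the first kernel_size timesteps.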