Add the models and process of conv_in_time in src/lava/proc/conv_in_time #833

Merged on Apr 25, 2024: 22 commits, branch conv_in_time into main.

Changes from 7 commits. Commit list:
e22f9bd
add the models and process of conv_in_time in src/lava/proc/conv_in_time
zeyuliu1037 Jan 31, 2024
1245c3a
remove unused library
zeyuliu1037 Feb 1, 2024
8084ed5
remove Trailing whitespace
zeyuliu1037 Feb 1, 2024
4d4494e
add unittest for conv in time and related pytorch ground truth
zeyuliu1037 Feb 5, 2024
0da012f
add fixed_pt version of conv in time
zeyuliu1037 Feb 5, 2024
0520a93
change input to spike_input
zeyuliu1037 Feb 5, 2024
6435422
Merge branch 'main' into conv_in_time
bamsumit Feb 22, 2024
0902787
add from lava.proc.conv import utils
zeyuliu1037 Feb 22, 2024
1f0d09d
remove unwanted comments
zeyuliu1037 Feb 22, 2024
d40008f
fixed some linting errors
zeyuliu1037 Feb 23, 2024
4424b20
Merge branch 'main' into conv_in_time
PhilippPlank Feb 23, 2024
8f4f147
Start all comments with upper case character & change the year for al…
zeyuliu1037 Feb 23, 2024
4029bb9
remove whitespace
zeyuliu1037 Feb 23, 2024
8df88f4
Merge branch 'main' into conv_in_time
bamsumit Mar 6, 2024
0e1bbde
continuation line under-indented
zeyuliu1037 Mar 6, 2024
de655a0
Merge branch 'conv_in_time' of https://github.com/zeyuliu1037/lava in…
zeyuliu1037 Mar 6, 2024
7ec9957
Trailing whitespace
zeyuliu1037 Mar 6, 2024
045bd58
shorten variables names
zeyuliu1037 Mar 14, 2024
98a89ab
Merge branch 'main' into conv_in_time
bamsumit Mar 29, 2024
d57282a
Merge branch 'main' into conv_in_time
bamsumit Apr 23, 2024
2973973
Merge branch 'main' into conv_in_time
PhilippPlank Apr 24, 2024
d02f4ab
Merge branch 'main' into conv_in_time
PhilippPlank Apr 25, 2024
101 changes: 101 additions & 0 deletions src/lava/proc/conv_in_time/models.py
@@ -0,0 +1,101 @@
# Copyright (C) 2021-22 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np

from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.proc.conv_in_time.process import ConvInTime
from lava.proc.conv import utils

# self.weights: [kernel_size, num_flat_output_neurons,
#                num_flat_input_neurons]
# self.a_buff: [num_flat_output_neurons, kernel_size]

class AbstractPyConvInTimeModel(PyLoihiProcessModel):
"""Abstract ProcessModel for Conv In Time: dense synaptic connections
whose incoming spikes are convolved with a weight kernel along the time
dimension.
"""
weights: np.ndarray = None
a_buff: np.ndarray = None
kernel_size: int = None

num_message_bits: np.ndarray = LavaPyType(np.ndarray, np.int8, precision=5)


def calc_act(self, s_in) -> np.ndarray:
"""
Calculate the activation buffer by traversing the kernel in
reverse order. Taking kernel_size=3 as an example, a_buff is
incremented by [weights[2] * s_in, weights[1] * s_in,
weights[0] * s_in].
"""

# sum([K, n_out, n_in] * [n_in,], axis=-1) = [K, n_out] -> [n_out, K]
kernel_size = self.weights.shape[0]
for i in range(kernel_size):
self.a_buff[:, i] += np.sum(
self.weights[kernel_size - i - 1] * s_in, axis=-1)

def update_act(self, s_in):
"""
Updates the activations for the connection.
Clears the first column of a_buff and rolls it to the last column,
then accumulates the activations for the current time step into
a_buff. This order of operations ensures that the input received at
time step t first affects the output at time step t + 1.
"""
self.a_buff[:, 0] = 0
# The flattened roll shifts every row one step to the left; the
# zeroed first column wraps around to fill each row's last column.
self.a_buff = np.roll(self.a_buff, -1)
self.calc_act(s_in)

def run_spk(self):
# The a_out sent at each timestep is a buffered value from dendritic
# accumulation at timestep t-1. This prevents deadlocking in
# networks with recurrent connectivity structures.
self.a_out.send(self.a_buff[:, 0])
if self.num_message_bits.item() > 0:
s_in = self.s_in.recv()
else:
s_in = self.s_in.recv().astype(bool)
self.update_act(s_in)


@implements(proc=ConvInTime, protocol=LoihiProtocol)
@requires(CPU)
@tag("floating_pt")
class PyConvInTimeFloat(AbstractPyConvInTimeModel):
"""Implementation of the Conv In Time Process with dense synaptic
connections in floating point precision. This short and simple
ProcessModel can be used for quick algorithmic prototyping, without
engaging with the nuances of a fixed point implementation.
"""
s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
a_buff: np.ndarray = LavaPyType(np.ndarray, float)
# weights is a 3D matrix of form (kernel_size, num_flat_output_neurons,
# num_flat_input_neurons) in C-order (row major).
weights: np.ndarray = LavaPyType(np.ndarray, float)
num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)

@implements(proc=ConvInTime, protocol=LoihiProtocol)
@requires(CPU)
@tag("fixed_pt")
class PyConvInTimeFixed(AbstractPyConvInTimeModel):
"""Conv In Time with fixed point synapse implementation."""
s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
a_buff: np.ndarray = LavaPyType(np.ndarray, float)
# weights is a 3D matrix of form (kernel_size, num_flat_output_neurons,
# num_flat_input_neurons) in C-order (row major).
weights: np.ndarray = LavaPyType(np.ndarray, float)
num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5)

def clamp_precision(self, x: np.ndarray) -> np.ndarray:
return utils.signed_clamp(x, bits=24)
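To make the buffered time-convolution above concrete, here is a minimal standalone NumPy sketch of the shift-and-accumulate logic in calc_act/update_act (variable names are illustrative, not from the PR). The roll is written with an explicit axis=1, which is equivalent to the model's flattened roll here because the first column is zeroed beforehand:

```python
import numpy as np

# Minimal sketch of AbstractPyConvInTimeModel's buffer logic (assumption:
# toy sizes and random data; this is not the PR's code).
rng = np.random.default_rng(0)
K, n_out, n_in, T = 3, 4, 2, 10
weights = rng.random((K, n_out, n_in))       # [kernel_size, n_out, n_in]
spikes = rng.integers(0, 2, size=(n_in, T))  # binary input spikes

a_buff = np.zeros((n_out, K))
outputs = []
for t in range(T):
    outputs.append(a_buff[:, 0].copy())   # run_spk sends the buffered value
    a_buff[:, 0] = 0
    a_buff = np.roll(a_buff, -1, axis=1)  # shift buffer one step left
    for i in range(K):
        # Reverse kernel order: the newest input lands on weights[K-1]
        a_buff[:, i] += weights[K - i - 1] @ spikes[:, t]
outputs = np.stack(outputs, axis=1)       # [n_out, T]

# Direct reference: a_out[t] = sum over k of weights[k] @ s_in[t - K + k]
reference = np.zeros((n_out, T))
for t in range(T):
    for k in range(K):
        if t - K + k >= 0:
            reference[:, t] += weights[k] @ spikes[:, t - K + k]

assert np.allclose(outputs, reference)
```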
124 changes: 124 additions & 0 deletions src/lava/proc/conv_in_time/process.py
@@ -0,0 +1,124 @@
# Copyright (C) 2021-23 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np
import typing as ty

from lava.magma.core.process.process import AbstractProcess, LogConfig
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import InPort, OutPort


class ConvInTime(AbstractProcess):
"""Connection Process that mimics a convolution of the incoming
events/spikes with a kernel along the time dimension. For a kernel of
size K, it realizes the following abstract behavior:
a_out[t] = weights[K-1] * s_in[t-1] + weights[K-2] * s_in[t-2]
+ ... + weights[0] * s_in[t-K]

Parameters
----------
weights : numpy.ndarray
3D connection weight matrix of form (kernel_size, num_flat_output_neurons,
num_flat_input_neurons) in C-order (row major).

weight_exp : int, optional
Shared weight exponent of base 2 used to scale magnitude of
weights, if needed. Mostly for fixed point implementations.
Unnecessary for floating point implementations.
Default value is 0.

num_weight_bits : int, optional
Shared weight width/precision used by weight. Mostly for fixed
point implementations. Unnecessary for floating point
implementations.
Default is for weights to use full 8 bit precision.

sign_mode : SignMode, optional
Shared indicator whether synapse is of type SignMode.NULL,
SignMode.MIXED, SignMode.EXCITATORY, or SignMode.INHIBITORY. If
SignMode.MIXED, the sign of the weight is
included in the weight bits and the fixed point weight used for
inference is scaled by 2.
Unnecessary for floating point implementations.

In the fixed point implementation, weights are scaled according to
the following equations:
w_scale = 8 - num_weight_bits + weight_exp + isMixed()
weights = weights * (2 ** w_scale)

num_message_bits : int, optional
Determines whether the ConvInTime Process deals with the incoming
spikes as binary spikes (num_message_bits = 0) or as graded
spikes (num_message_bits > 0). Default is 0.
"""
def __init__(self,
*,
weights: np.ndarray,
name: ty.Optional[str] = None,
num_message_bits: ty.Optional[int] = 0,
log_config: ty.Optional[LogConfig] = None,
**kwargs) -> None:

super().__init__(weights=weights,
num_message_bits=num_message_bits,
name=name,
log_config=log_config,
**kwargs)

self._validate_weights(weights)
shape = weights.shape # [kernel_size, n_flat_output_neurons, n_flat_input_neurons]
# Ports
self.s_in = InPort(shape=(shape[2],))
self.a_out = OutPort(shape=(shape[1],))

# Variables
self.weights = Var(shape=shape, init=weights)
self.a_buff = Var(shape=(shape[1], shape[0]), init=0)
self.num_message_bits = Var(shape=(1,), init=num_message_bits)

@staticmethod
def _validate_weights(weights: np.ndarray) -> None:
if len(np.shape(weights)) != 3:
raise ValueError("ConvInTime Process 'weights' expects a 3D "
f"matrix, got shape {np.shape(weights)}.")


########## validation code ##########
# from lava.proc.conv_in_time.process import ConvInTime
# from lava.proc import io
# import numpy as np
# import torch
# import torch.nn as nn
# np.set_printoptions(linewidth=np.inf)

# num_steps = 10
# n_flat_input_neurons = 2
# n_flat_output_neurons = 3
# kernel_size = 3
# input = np.random.choice([0, 1], size=(n_flat_input_neurons, num_steps))
# sender = io.source.RingBuffer(data=input)
# weights = np.random.rand(kernel_size, n_flat_output_neurons, n_flat_input_neurons)
# conv_in_time = ConvInTime(weights=weights, name='conv_in_time')
# receiver = io.sink.RingBuffer(shape=(n_flat_output_neurons,), buffer=num_steps)

# sender.s_out.connect(conv_in_time.s_in)
# conv_in_time.a_out.connect(receiver.a_in)

# from lava.magma.core.run_conditions import RunSteps
# from lava.magma.core.run_configs import Loihi1SimCfg

# run_condition = RunSteps(num_steps=num_steps)
# run_cfg = Loihi1SimCfg(select_tag="floating_pt")

# conv_in_time.run(condition=run_condition, run_cfg=run_cfg)
# output = receiver.data.get()
# conv_in_time.stop()

# tensor_input = torch.tensor(input, dtype=torch.float32)
# tensor_weights = torch.tensor(weights, dtype=torch.float32)
# conv_layer = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=3, bias=False)
# conv_layer.weight = nn.Parameter(tensor_weights.permute(1, 2, 0))
# torch_output = conv_layer(tensor_input.unsqueeze(0)).squeeze(0).detach().numpy()

# print(' lava output: ', output)
# print('torch output: ', torch_output)
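The fixed-point scaling equations quoted in the ConvInTime docstring are compact enough to misread, so here they are as executable Python. This is a sketch under stated assumptions: the helper name scale_weights is hypothetical and not part of the PR, and isMixed() is rendered as a boolean flag.

```python
import numpy as np


def scale_weights(weights: np.ndarray,
                  num_weight_bits: int = 8,
                  weight_exp: int = 0,
                  is_mixed: bool = True) -> np.ndarray:
    """Apply w_scale = 8 - num_weight_bits + weight_exp + isMixed()."""
    w_scale = 8 - num_weight_bits + weight_exp + int(is_mixed)
    return weights * (2 ** w_scale)


# E.g. 6-bit mixed-sign weights with weight_exp = 0 are scaled by 2**3 = 8.
w = np.array([[-3, 5], [7, -1]])
print(scale_weights(w, num_weight_bits=6))  # [[-24  40], [ 56  -8]]
```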
Empty file.
3 changes: 3 additions & 0 deletions tests/lava/proc/conv_in_time/ground_truth/quantized_weights_k_out_in.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/lava/proc/conv_in_time/ground_truth/spike_input.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/lava/proc/conv_in_time/ground_truth/torch_output.npy
Git LFS file not shown
67 changes: 67 additions & 0 deletions tests/lava/proc/conv_in_time/test_process.py
@@ -0,0 +1,67 @@
# Copyright (C) 2021-22 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import unittest
import os
import numpy as np
from lava.proc.conv_in_time.process import ConvInTime
from lava.proc import io

from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg
from lava.proc.conv import utils

if utils.TORCH_IS_AVAILABLE:
import torch
import torch.nn as nn
compare = True
# In this case, the test compares against torch output computed at runtime
else:
compare = False
# In this case, the test compares against saved torch ground truth


class TestConvInTimeProcess(unittest.TestCase):
"""Tests for Conv class"""
def test_init(self) -> None:
"""Tests instantiation of Conv In Time"""
num_steps = 10
kernel_size = 3
n_flat_input_neurons = 2
n_flat_output_neurons = 5
if compare:
spike_input = np.random.choice(
[0, 1], size=(n_flat_input_neurons, num_steps))
weights = np.random.randint(
256,
size=[kernel_size, n_flat_output_neurons, n_flat_input_neurons]
) - 128
else:
spike_input = np.load(os.path.join(
os.path.dirname(__file__), "ground_truth/spike_input.npy"))
weights = np.load(os.path.join(
os.path.dirname(__file__),
"ground_truth/quantized_weights_k_out_in.npy"))
sender = io.source.RingBuffer(data=spike_input)
conv_in_time = ConvInTime(weights=weights, name='conv_in_time')

receiver = io.sink.RingBuffer(shape=(n_flat_output_neurons,), buffer=num_steps+1)

sender.s_out.connect(conv_in_time.s_in)
conv_in_time.a_out.connect(receiver.a_in)


run_condition = RunSteps(num_steps=num_steps+1)
run_cfg = Loihi1SimCfg(select_tag="floating_pt")

conv_in_time.run(condition=run_condition, run_cfg=run_cfg)
output = receiver.data.get()
conv_in_time.stop()

if compare:
tensor_input = torch.tensor(spike_input, dtype=torch.float32)
tensor_weights = torch.tensor(weights, dtype=torch.float32)
conv_layer = nn.Conv1d(in_channels=n_flat_input_neurons,
out_channels=n_flat_output_neurons,
kernel_size=kernel_size, bias=False)
# Permute the weights from (kernel, out, in) to the torch Conv1d
# format (out, in, kernel)
conv_layer.weight = nn.Parameter(tensor_weights.permute(1, 2, 0))
torch_output = conv_layer(tensor_input.unsqueeze(0)).squeeze(0).detach().numpy()
else:
torch_output = np.load(os.path.join(os.path.dirname(__file__), "ground_truth/torch_output.npy"))

self.assertEqual(output.shape, (n_flat_output_neurons, num_steps + 1))
# After kernel_size timesteps, the output should match the torch output
self.assertTrue(np.allclose(output[:, kernel_size:], torch_output))
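The slice output[:, kernel_size:] in the final assertion deserves a brief note: the process emits a buffered value on every step and needs kernel_size steps before a full kernel's worth of input has accumulated, so the valid region of the lava output has exactly as many columns as torch's Conv1d output. A quick shape check using the test's own parameters:

```python
# Sketch: column counts of the two compared arrays line up.
num_steps, kernel_size = 10, 3
lava_cols = (num_steps + 1) - kernel_size   # lava output columns kept: 8
torch_cols = num_steps - kernel_size + 1    # Conv1d "valid" output length: 8
assert lava_cols == torch_cols
```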
