TernaryLIF and refactoring of LIF to inherit from AbstractLIF #151

Merged Jan 24, 2022 (5 commits)
272 changes: 213 additions & 59 deletions src/lava/proc/lif/models.py
@@ -1,79 +1,75 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.proc.lif.process import LIF
from lava.proc.lif.process import LIF, TernaryLIF


@implements(proc=LIF, protocol=LoihiProtocol)
@requires(CPU)
@tag('floating_pt')
class PyLifModelFloat(PyLoihiProcessModel):
"""Implementation of Leaky-Integrate-and-Fire neural process in floating
point precision. This short and simple ProcessModel can be used for quick
algorithmic prototyping, without engaging with the nuances of a fixed
point implementation.
class AbstractPyLifModelFloat(PyLoihiProcessModel):
"""Abstract implementation of floating point precision
leaky-integrate-and-fire neuron model. Specific implementations
inherit from here.
"""
a_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float)
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
s_out = None # This will be an OutPort of different LavaPyTypes
u: np.ndarray = LavaPyType(np.ndarray, float)
v: np.ndarray = LavaPyType(np.ndarray, float)
bias: np.ndarray = LavaPyType(np.ndarray, float)
bias_exp: np.ndarray = LavaPyType(np.ndarray, float)
du: float = LavaPyType(float, float)
dv: float = LavaPyType(float, float)
vth: float = LavaPyType(float, float)

def run_spk(self):
a_in_data = self.a_in.recv()
def spiking_activation(self):
"""Abstract method to define the activation function that determines
how spikes are generated"""
raise NotImplementedError("spiking activation() cannot be called from "
"an abstract ProcessModel")

def subthr_dynamics(self, activation_in: np.ndarray):
"""Common sub-threshold dynamics of current and voltage variables for
all LIF models. This is where the 'leaky integration' happens."""
self.u[:] = self.u * (1 - self.du)
self.u[:] += a_in_data
bias = self.bias * (2**self.bias_exp)
self.u[:] += activation_in
bias = self.bias * (2 ** self.bias_exp)
self.v[:] = self.v * (1 - self.dv) + self.u + bias
s_out = self.v >= self.vth
self.v[s_out] = 0 # Reset voltage to 0
self.s_out.send(s_out)

def reset_voltage(self, spike_vector: np.ndarray):
"""Voltage reset behaviour. This can differ for different neuron
models."""
self.v[spike_vector] = 0
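For reviewers, a minimal standalone NumPy sketch of one floating point time step, mirroring the subthr_dynamics, spiking_activation and reset_voltage hooks above (du, dv, vth and the input values are illustrative, not taken from the tests):

import numpy as np

du, dv, vth = 0.1, 0.1, 10.0             # illustrative decay and threshold values
u = np.zeros(4)                          # current state
v = np.zeros(4)                          # voltage state
bias = np.full(4, 2.0)
bias_exp = np.zeros(4)
a_in = np.array([5.0, 0.0, 20.0, 1.0])   # synaptic input for this time step

u = u * (1 - du) + a_in                        # leaky current integration
v = v * (1 - dv) + u + bias * (2 ** bias_exp)  # leaky voltage integration
s_out = v >= vth                               # boolean spike vector
v[s_out] = 0                                   # reset spiked neurons
# s_out == [False, False, True, False]; v == [7., 2., 0., 3.]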

@implements(proc=LIF, protocol=LoihiProtocol)
@requires(CPU)
@tag('bit_accurate_loihi', 'fixed_pt')
class PyLifModelBitAcc(PyLoihiProcessModel):
"""Implementation of Leaky-Integrate-and-Fire neural process bit-accurate
with Loihi's hardware LIF dynamics, which means, it mimics Loihi
behaviour bit-by-bit.
def run_spk(self):
"""The run function that performs the actual computation during
execution orchestrated by a PyLoihiProcessModel using the
LoihiProtocol"""
a_in_data = self.a_in.recv()
self.subthr_dynamics(activation_in=a_in_data)
s_out = self.spiking_activation()
self.reset_voltage(spike_vector=s_out)
self.s_out.send(s_out)

Currently missing features (compared to Loihi 1 hardware):
- refractory period after spiking
- axonal delays

Precisions of state variables
-----------------------------
du: unsigned 12-bit integer (0 to 4095)
dv: unsigned 12-bit integer (0 to 4095)
bias: signed 13-bit integer (-4096 to 4095). Mantissa part of neuron bias.
bias_exp: unsigned 3-bit integer (0 to 7). Exponent part of neuron bias.
vth: unsigned 17-bit integer (0 to 131071)
"""
class AbstractPyLifModelFixed(PyLoihiProcessModel):
"""Abstract implementation of fixed point precision
leaky-integrate-and-fire neuron model. Implementations like those
bit-accurate with Loihi hardware inherit from here."""
a_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16, precision=16)
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
s_out = None # This will be an OutPort of different LavaPyTypes
u: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
v: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
du: int = LavaPyType(int, np.uint16, precision=12)
dv: int = LavaPyType(int, np.uint16, precision=12)
bias: np.ndarray = LavaPyType(np.ndarray, np.int16, precision=13)
bias_exp: np.ndarray = LavaPyType(np.ndarray, np.int16, precision=3)
vth: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=17)

def __init__(self):
super(PyLifModelBitAcc, self).__init__()
super(AbstractPyLifModelFixed, self).__init__()
# ds_offset and dm_offset are 1-bit registers in Loihi 1, which are
# added to du and dv variables to compute effective decay constants
# for current and voltage, respectively. They enable setting decay
@@ -83,9 +79,9 @@ def __init__(self):
# outside, but this will change in the future.
self.ds_offset = 1
self.dm_offset = 0
self.b_vth_computed = False
self.isbiasscaled = False
self.isthrscaled = False
self.effective_bias = 0
self.effective_vth = 0
# Let's define some bit-widths from Loihi
# State variables u and v are 24-bits wide
self.uv_bitwidth = 24
@@ -97,17 +93,29 @@ def __init__(self):
self.vth_shift = 6
self.act_shift = 6

def run_spk(self):
# Receive synaptic input
a_in_data = self.a_in.recv()
def scale_bias(self):
"""Scale bias with bias exponent by taking into account sign of the
exponent"""
self.effective_bias = np.where(self.bias_exp >= 0, np.left_shift(
self.bias, self.bias_exp), np.right_shift(self.bias,
-self.bias_exp))
self.isbiasscaled = True
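A small worked example of the signed bias-exponent scaling above (illustrative values). Note that np.where evaluates both shift branches, so shifts by negative counts are computed and then discarded, which also holds for the vectorized form in scale_bias:

import numpy as np

bias = np.array([12, -8], dtype=np.int16)
bias_exp = np.array([3, -1], dtype=np.int16)
effective_bias = np.where(bias_exp >= 0,
                          np.left_shift(bias, bias_exp),
                          np.right_shift(bias, -bias_exp))
# effective_bias == [96, -4]: 12 << 3 == 96, and -8 >> 1 == -4
# (numpy's right_shift is arithmetic, so the sign is preserved)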

# Compute effective bias and threshold only once, not every time-step
if not self.b_vth_computed:
self.effective_bias = np.left_shift(self.bias, self.bias_exp)
# In Loihi, user specified threshold is just the mantissa, with a
# constant exponent of 6
self.effective_vth = np.left_shift(self.vth, self.vth_shift)
self.b_vth_computed = True
def scale_threshold(self):
"""Placeholder method for scaling threshold(s)."""
raise NotImplementedError("spiking activation() cannot be called from "
"an abstract ProcessModel")

def spiking_activation(self):
"""Placeholder method to specify spiking behaviour of a LIF neuron."""
raise NotImplementedError("spiking_activation() cannot be called from "
"an abstract ProcessModel")

def subthr_dynamics(self, activation_in: np.ndarray):
"""Common sub-threshold dynamics of current and voltage variables for
all LIF models. This is where the 'leaky integration' happens."""

# Update current
# --------------
@@ -120,14 +128,18 @@ def run_spk(self):
decayed_curr), self.decay_shift)
decayed_curr = np.int32(decayed_curr)
# Hardware left-shifts synaptic input for MSB alignment
a_in_data = np.left_shift(a_in_data, self.act_shift)
activation_in = np.left_shift(activation_in, self.act_shift)
# Add synaptic input to decayed current
decayed_curr += a_in_data
decayed_curr += activation_in
# Check if value of current is within bounds of 24-bit. Overflows are
# handled by wrapping around modulo 2 ** 23. E.g., (2 ** 23) + k
# becomes k and -(2**23 + k) becomes -k
sign_of_curr = np.sign(decayed_curr)
# when decayed_curr is 0, we don't care about its sign. But np.mod
# needs something non-zero to avoid the divide-by-zero warning
sign_of_curr[sign_of_curr == 0] = 1
wrapped_curr = np.mod(decayed_curr,
np.sign(decayed_curr) * self.max_uv_val)
sign_of_curr * self.max_uv_val)
self.u[:] = wrapped_curr
# Update voltage
# --------------
@@ -145,9 +157,151 @@ def run_spk(self):
updated_volt = decayed_volt + self.u + self.effective_bias
self.v[:] = np.clip(updated_volt, neg_voltage_limit, pos_voltage_limit)
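To illustrate the two overflow behaviours, a short sketch: current wraps around modulo 2 ** 23 while voltage saturates. This assumes the voltage limits are +/-(2 ** 23 - 1), consistent with the 24-bit width noted in __init__ (the exact limit computation sits in the collapsed part of this hunk):

import numpy as np

max_uv_val = 2 ** 23                     # u and v are 24-bit signed

# Current u wraps around modulo 2 ** 23, preserving sign:
decayed_curr = np.array([2 ** 23 + 5, -(2 ** 23 + 5), 0, 100])
sign_of_curr = np.sign(decayed_curr)
sign_of_curr[sign_of_curr == 0] = 1      # avoid modulo by zero
wrapped_curr = np.mod(decayed_curr, sign_of_curr * max_uv_val)
# wrapped_curr == [5, -5, 0, 100]

# Voltage v saturates at the limits instead of wrapping:
updated_volt = np.array([2 ** 25, -(2 ** 25), 100])
v = np.clip(updated_volt, -max_uv_val + 1, max_uv_val - 1)
# v == [8388607, -8388607, 100]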

# Spike when exceeds threshold
# ----------------------------
s_out = self.v >= self.effective_vth
def reset_voltage(self, spike_vector: np.ndarray):
"""Voltage reset behaviour. This can differ for different neuron
models.
"""
self.v[spike_vector] = 0

def run_spk(self):
"""The run function that performs the actual computation during
execution orchestrated by a PyLoihiProcessModel using the
LoihiProtocol
"""
# Receive synaptic input
a_in_data = self.a_in.recv()

# Compute effective bias and threshold only once, not every time-step
if not self.isbiasscaled:
self.scale_bias()

if not self.isthrscaled:
self.scale_threshold()

self.subthr_dynamics(activation_in=a_in_data)

s_out = self.spiking_activation()

# Reset voltage of spiked neurons to 0
self.v[s_out] = 0
self.reset_voltage(spike_vector=s_out)
self.s_out.send(s_out)


@implements(proc=LIF, protocol=LoihiProtocol)
@requires(CPU)
@tag('floating_pt')
class PyLifModelFloat(AbstractPyLifModelFloat):
"""Implementation of Leaky-Integrate-and-Fire neural process in floating
point precision. This short and simple ProcessModel can be used for quick
algorithmic prototyping, without engaging with the nuances of a fixed
point implementation.
"""
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
vth: float = LavaPyType(float, float)

def spiking_activation(self):
"""Spiking activation function for LIF"""
return self.v >= self.vth


@implements(proc=TernaryLIF, protocol=LoihiProtocol)
@requires(CPU)
@tag('floating_pt')
class PyTernLifModelFloat(AbstractPyLifModelFloat):
"""Implementation of Ternary Leaky-Integrate-and-Fire neural process in
floating point precision. This ProcessModel builds upon the floating
point ProcessModel for LIF by adding upper and lower threshold voltages.
"""
# Spikes now become 2-bit signed floating point numbers
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float, precision=2)
vth_hi: float = LavaPyType(float, float)
vth_lo: float = LavaPyType(float, float)

def spiking_activation(self):
"""Spiking activation for T-LIF: -1 spikes below lower threshold,
+1 spikes above upper threshold"""
return (-1) * (self.v <= self.vth_lo) + (self.v >= self.vth_hi)

def reset_voltage(self, spike_vector: np.ndarray):
"""Reset voltage of all spiking neurons to 0"""
self.v[spike_vector != 0] = 0 # Reset voltage to 0 wherever we spiked
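A quick numeric check of the ternary activation (illustrative values): voltages at or below vth_lo emit -1, voltages at or above vth_hi emit +1, everything in between stays silent, and only spiking neurons are reset:

import numpy as np

v = np.array([-12.0, 5.0, 9.0])
vth_lo, vth_hi = -10.0, 8.0
s_out = (-1) * (v <= vth_lo) + (v >= vth_hi)
# s_out == [-1, 0, 1]
v[s_out != 0] = 0                        # reset only where a spike occurred
# v == [0., 5., 0.]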


@implements(proc=LIF, protocol=LoihiProtocol)
@requires(CPU)
@tag('bit_accurate_loihi', 'fixed_pt')
class PyLifModelBitAcc(AbstractPyLifModelFixed):
"""Implementation of Leaky-Integrate-and-Fire neural process bit-accurate
with Loihi's hardware LIF dynamics, which means it mimics Loihi
behaviour bit-by-bit.

Currently missing features (compared to Loihi 1 hardware):
- refractory period after spiking
- axonal delays

Precisions of state variables
-----------------------------
du: unsigned 12-bit integer (0 to 4095)
dv: unsigned 12-bit integer (0 to 4095)
bias: signed 13-bit integer (-4096 to 4095). Mantissa part of neuron bias.
bias_exp: unsigned 3-bit integer (0 to 7). Exponent part of neuron bias.
vth: unsigned 17-bit integer (0 to 131071)
"""
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
vth: int = LavaPyType(int, np.int32, precision=17)

def __init__(self):
super(PyLifModelBitAcc, self).__init__()
self.effective_vth = 0

def scale_threshold(self):
"""Scale threshold according to the way Loihi hardware scales it. In
Loihi hardware, threshold is left-shifted by 6 bits to MSB-align it
with other state variables of higher precision.
"""
self.effective_vth = np.left_shift(self.vth, self.vth_shift)
self.isthrscaled = True

def spiking_activation(self):
"""Spike when voltage exceeds threshold"""
return self.v >= self.effective_vth
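For reference, the 6-bit MSB alignment amounts to multiplying the user-facing threshold mantissa by 64 (illustrative value):

import numpy as np

vth = 100                                # user-specified threshold mantissa
effective_vth = np.left_shift(vth, 6)    # constant exponent of 6 on Loihi
# effective_vth == 6400, directly comparable against the 24-bit voltage v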


@implements(proc=TernaryLIF, protocol=LoihiProtocol)
@requires(CPU)
@tag('fixed_pt')
class PyTernLifModelFixed(AbstractPyLifModelFixed):
"""Implementation of Ternary Leaky-Integrate-and-Fire neural process
with fixed point precision. Some state variables are bit-accurate
with Loihi's hardware LIF dynamics, which means they follow Loihi
behaviour bit-by-bit.

See Also
--------
lava.proc.lif.models.PyLifModelBitAcc: Bit-Accurate LIF neuron model
"""
# Spikes now become 2-bit signed integers
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int, precision=2)
vth_hi: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
vth_lo: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)

def __init__(self):
super(PyTernLifModelFixed, self).__init__()
self.effective_vth_hi = 0
self.effective_vth_lo = 0

def scale_threshold(self):
self.effective_vth_hi = np.left_shift(self.vth_hi, self.vth_shift)
self.effective_vth_lo = np.left_shift(self.vth_lo, self.vth_shift)
self.isthrscaled = True

def spiking_activation(self):
# Spike when exceeds threshold
# ----------------------------
neg_spikes = self.v <= self.effective_vth_lo
pos_spikes = self.v >= self.effective_vth_hi
return (-1) * neg_spikes + pos_spikes

def reset_voltage(self, spike_vector: np.ndarray):
"""Reset voltage of all spiking neurons to 0"""
self.v[spike_vector != 0] = 0 # Reset voltage to 0 wherever we spiked
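Putting the fixed point ternary pieces together, a minimal sketch of one scale-and-spike pass exercising scale_threshold, spiking_activation and reset_voltage as defined above (all values illustrative):

import numpy as np

vth_hi = np.array([100], dtype=np.int32)
vth_lo = np.array([-100], dtype=np.int32)
effective_vth_hi = np.left_shift(vth_hi, 6)    # 6400
effective_vth_lo = np.left_shift(vth_lo, 6)    # -6400
v = np.array([-7000, 500, 6500], dtype=np.int32)
s_out = (-1) * (v <= effective_vth_lo) + (v >= effective_vth_hi)
# s_out == [-1, 0, 1]
v[s_out != 0] = 0                              # reset spiking neurons
# v == [0, 500, 0]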