Fix/learning lif (#536)
* add learning lif float and tests

* lint

* changing tutorial to inherit RSTDP from LearningLIF

Co-authored-by: Philipp <[email protected]>
bala-git9 and weidel-p authored Dec 12, 2022
1 parent 88f419c commit 79e4202
Showing 1 changed file with 16 additions and 66 deletions.
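The heart of this commit: the tutorial's hand-rolled RSTDPLIF Process is replaced by a thin subclass of LearningLIF, so the duplicated LIF constructor and port declarations in the diff below can be deleted. A minimal sketch of what the simplified class looks like and how it might be used after this change (the instantiation and its parameter values are illustrative placeholders inferred from the constructor removed in this diff, not taken from the commit):

    from lava.proc.lif.process import LearningLIF


    class RSTDPLIF(LearningLIF):
        # All state (u, v, du, dv, vth) and the third-factor input port
        # are now inherited from LearningLIF.
        pass


    # Hypothetical usage; LearningLIF presumably accepts the usual LIF
    # arguments, judging by the signature deleted below.
    lif = RSTDPLIF(shape=(1,), du=0.1, dv=0.1, vth=10.0)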
82 changes: 16 additions & 66 deletions tutorials/in_depth/three_factor_learning/utils.py
@@ -6,7 +6,7 @@
 import typing as ty
 import numpy as np
 
-from lava.proc.lif.process import LIF, AbstractLIF, LogConfig
+from lava.proc.lif.process import LIF, AbstractLIF, LogConfig, LearningLIF
 from lava.proc.io.source import RingBuffer
 from lava.proc.dense.process import LearningDense, Dense
 from lava.magma.core.process.neuron import LearningNeuronProcess
@@ -26,79 +26,28 @@
 )
 
 
-class RSTDPLIF(LearningNeuronProcess, AbstractLIF):
-    """Leaky-Integrate-and-Fire (LIF) neural Process with RSTDP learning rule.
-
-    Parameters
-    ----------
-    shape : tuple(int)
-        Number and topology of LIF neurons.
-    u : float, list, numpy.ndarray, optional
-        Initial value of the neurons' current.
-    v : float, list, numpy.ndarray, optional
-        Initial value of the neurons' voltage (membrane potential).
-    du : float, optional
-        Inverse of decay time-constant for current decay. Currently, only a
-        single decay can be set for the entire population of neurons.
-    dv : float, optional
-        Inverse of decay time-constant for voltage decay. Currently, only a
-        single decay can be set for the entire population of neurons.
-    bias_mant : float, list, numpy.ndarray, optional
-        Mantissa part of neuron bias.
-    bias_exp : float, list, numpy.ndarray, optional
-        Exponent part of neuron bias, if needed. Mostly for fixed point
-        implementations. Ignored for floating point implementations.
-    vth : float, optional
-        Neuron threshold voltage, exceeding which, the neuron will spike.
-        Currently, only a single threshold can be set for the entire
-        population of neurons.
-    """
-
-    def __init__(
-            self,
-            *,
-            shape: ty.Tuple[int, ...],
-            u: ty.Optional[ty.Union[float, list, np.ndarray]] = 0,
-            v: ty.Optional[ty.Union[float, list, np.ndarray]] = 0,
-            du: ty.Optional[float] = 0,
-            dv: ty.Optional[float] = 0,
-            bias_mant: ty.Optional[ty.Union[float, list, np.ndarray]] = 0,
-            bias_exp: ty.Optional[ty.Union[float, list, np.ndarray]] = 0,
-            vth: ty.Optional[float] = 10,
-            name: ty.Optional[str] = None,
-            log_config: ty.Optional[LogConfig] = None,
-            **kwargs) -> None:
-        super().__init__(shape=shape, u=u, v=v, du=du, dv=dv,
-                         bias_mant=bias_mant,
-                         bias_exp=bias_exp, name=name,
-                         log_config=log_config,
-                         **kwargs)
-        self.vth = Var(shape=(1,), init=vth)
-
-        self.a_third_factor_in = InPort(shape=shape)
+class RSTDPLIF(LearningLIF):
+    pass
 
 
 @implements(proc=RSTDPLIF, protocol=LoihiProtocol)
 @requires(CPU)
-@tag('floating_pt')
+@tag("floating_pt")
 class RSTDPLIFModelFloat(LearningNeuronModelFloat, AbstractPyLifModelFloat):
     """Implementation of Leaky-Integrate-and-Fire neural
     process in floating point precision with learning enabled
     to do R-STDP.
     """
-    # Graded reward input spikes
-    a_third_factor_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float)
 
     s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
     vth: float = LavaPyType(float, float)
 
     def __init__(self, proc_params):
         super().__init__(proc_params)
-        self.s_out_buff = np.zeros(proc_params['shape'])
+        self.s_out_buff = np.zeros(proc_params["shape"])
 
     def spiking_activation(self):
-        """Spiking activation function for Learning LIF.
-        """
+        """Spiking activation function for Learning LIF."""
         return self.v > self.vth
 
     def calculate_third_factor_trace(self, s_graded_in: float) -> float:
@@ -133,11 +82,13 @@ def run_spk(self) -> None:
         s_out_y1: sends the post-synaptic spike times.
         s_out_y2: sends the graded third-factor reward signal.
         """
+
+        self.y1 = self.compute_post_synaptic_trace(self.s_out_buff)
+
         super().run_spk()
 
         a_graded_in = self.a_third_factor_in.recv()
 
-        self.y1 = self.compute_post_synaptic_trace(self.s_out_buff)
         self.y2 = self.calculate_third_factor_trace(a_graded_in)
 
         self.s_out_bap.send(self.s_out_buff)
@@ -148,7 +99,7 @@ def run_spk(self) -> None:
 
 @implements(proc=RSTDPLIF, protocol=LoihiProtocol)
 @requires(CPU)
-@tag('bit_accurate_loihi', 'fixed_pt')
+@tag("bit_accurate_loihi", "fixed_pt")
 class RSTDPLIFBitAcc(LearningNeuronModelFixed, AbstractPyLifModelFixed):
     """Implementation of RSTDP Leaky-Integrate-and-Fire neural
     process bit-accurate with Loihi's hardware LIF dynamics,
@@ -169,16 +120,14 @@ class RSTDPLIFBitAcc(LearningNeuronModelFixed, AbstractPyLifModelFixed):
     - vth: unsigned 17-bit integer (0 to 131071).
     """
-    # Graded reward input spikes
-    a_third_factor_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float)
 
     s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=24)
     vth: int = LavaPyType(int, np.int32, precision=17)
 
     def __init__(self, proc_params):
         super().__init__(proc_params)
         self.effective_vth = 0
-        self.s_out_buff = np.zeros(proc_params['shape'])
+        self.s_out_buff = np.zeros(proc_params["shape"])
 
     def scale_threshold(self):
         """Scale threshold according to the way Loihi hardware scales it. In
@@ -189,10 +138,9 @@ def scale_threshold(self):
         self.isthrscaled = True
 
     def spiking_activation(self):
-        """Spike when voltage exceeds threshold.
-        """
+        """Spike when voltage exceeds threshold."""
         return self.v > self.effective_vth
 
     def calculate_third_factor_trace(self, s_graded_in: float) -> float:
         """Generate's a third factor reward traces based on
         graded input spikes to the Learning LIF process.
@@ -225,11 +173,13 @@ def run_spk(self) -> None:
         s_out_y1: sends the post-synaptic spike times.
         s_out_y2: sends the graded third-factor reward signal.
         """
+
+        self.y1 = self.compute_post_synaptic_trace(self.s_out_buff)
+
        super().run_spk()
 
         a_graded_in = self.a_third_factor_in.recv()
 
-        self.y1 = self.compute_post_synaptic_trace(self.s_out_buff)
         self.y2 = self.calculate_third_factor_trace(a_graded_in)
 
         self.s_out_bap.send(self.s_out_buff)
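Beyond the cosmetic quote-style changes, the substantive fix in both run_spk implementations is ordering: y1 is now computed from s_out_buff before super().run_spk() refills that buffer with the current step's spikes, presumably so the post-synaptic trace is updated from the spikes actually sent on the previous timestep. A toy sketch of why the order matters (the decay-and-accumulate trace below is an illustrative stand-in, not Lava's exact dynamics):

    import numpy as np

    def post_trace(trace, spikes, decay=0.9):
        # Hypothetical exponential trace: decay the old value, add new spikes.
        return decay * trace + spikes

    trace = np.zeros(1)
    s_out_buff = np.array([1.0])  # spikes emitted on the previous step

    # Fixed order: fold last step's spikes into the trace first ...
    trace = post_trace(trace, s_out_buff)
    # ... then advance the neuron, which overwrites the spike buffer.
    s_out_buff = np.array([0.0])  # this step's spikes (illustrative)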
