Skip to content

Commit

Permalink
eliminate sum process
Browse files Browse the repository at this point in the history
  • Loading branch information
SveaMeyer13 committed Oct 10, 2022
1 parent 92e908f commit b6b3763
Showing 1 changed file with 14 additions and 51 deletions.
65 changes: 14 additions & 51 deletions tutorials/lava/lib/dnf/dnf_regimes/script_dnf_regimes.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,20 @@
import unittest
import numpy as np
import typing as ty

from lava.lib.dnf.inputs.rate_code_spike_gen.process import RateCodeSpikeGen
from lava.proc.embedded_io.spike import PyToNxAdapter, NxToPyAdapter
from lava.proc.io.sink import RingBuffer
from lava.lib.dnf.connect.connect import connect
from lava.lib.dnf.operations.operations import Convolution
from lava.lib.dnf.inputs.gauss_pattern.process import GaussPattern

from lava.magma.core.run_configs import Loihi1SimCfg
from lava.proc.monitor.process import Monitor
from lava.lib.dnf.kernels.kernels import MultiPeakKernel, Kernel

from utils import plot_1d, animated_1d_plot

from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.process.ports.ports import InPort, OutPort
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.resources import CPU

from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi2HwCfg
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.proc.lif.process import LIF
from lava.proc.dense.process import Dense


class Sum(AbstractProcess):
def __init__(self, shape: ty.Tuple[int, ...]) -> None:
super().__init__(shape=shape)

self.in_port_1 = InPort(shape=shape)
self.in_port_2 = InPort(shape=shape)
self.out_port = OutPort(shape=shape)


@implements(proc=Sum, protocol=LoihiProtocol)
@requires(CPU)
class SumProcessModel(PyLoihiProcessModel):
in_port_1: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool)
in_port_2: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool)
out_port: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool)

# TODO (MR): Fix the bug (?) with multiple Dense Processes that have
# different num_message_bit values.
def run_spk(self) -> None:
input_1 = self.in_port_1.recv()
input_2 = self.in_port_2.recv()
#output = input_1.astype(int) + input_2.astype(int)
output = input_1 + input_2

self.out_port.send(output)


class Architecture:
"""This class structure is not required and is only used here to reduce
code duplication for different examples."""
Expand All @@ -78,11 +36,13 @@ def __init__(self,

self.spike_generator_1 = RateCodeSpikeGen(shape=shape)
self.spike_generator_2 = RateCodeSpikeGen(shape=shape)
self.sum = Sum(shape=shape)

if loihi2:
self.injector = PyToNxAdapter(shape=shape)
self.injector1 = PyToNxAdapter(shape=shape)
self.injector2 = PyToNxAdapter(shape=shape)

self.input_dense = Dense(weights=np.eye(shape[0]) * 25)
self.input_dense1 = Dense(weights=np.eye(shape[0]) * 25)
self.input_dense2 = Dense(weights=np.eye(shape[0]) * 25)

self.dnf = LIF(shape=shape, du=409, dv=2047, vth=200)
dense = connect(self.dnf.s_out, self.dnf.a_in, [Convolution(kernel)])
Expand All @@ -94,15 +54,17 @@ def __init__(self,
self.gauss_pattern_1.a_out.connect(self.spike_generator_1.a_in)
self.gauss_pattern_2.a_out.connect(self.spike_generator_2.a_in)

self.spike_generator_1.s_out.connect(self.sum.in_port_1)
self.spike_generator_2.s_out.connect(self.sum.in_port_2)
if loihi2:
self.sum.out_port.connect(self.injector.inp)
self.injector.out.connect(self.input_dense.s_in)
self.spike_generator_1.s_out.connect(self.injector1.inp)
self.spike_generator_2.s_out.connect(self.injector2.inp)
self.injector1.out.connect(self.input_dense1.s_in)
self.injector2.out.connect(self.input_dense2.s_in)
else:
self.sum.out_port.connect(self.input_dense.s_in)
self.spike_generator_1.s_out.connect(self.input_dense1.s_in)
self.spike_generator_2.s_out.connect(self.input_dense2.s_in)

self.input_dense.a_out.connect(self.dnf.a_in)
self.input_dense1.a_out.connect(self.dnf.a_in)
self.input_dense2.a_out.connect(self.dnf.a_in)

if loihi2:
self.dnf.s_out.connect(self.spike_reader.inp)
Expand Down Expand Up @@ -139,6 +101,7 @@ def run(self):
self.gauss_pattern_1.run(condition=RunSteps(num_steps=200),
run_cfg=self.run_cfg)


def plot(self):
# Get probed data from monitors
data_dnf = self.py_receiver.data.get().transpose()
Expand Down

0 comments on commit b6b3763

Please sign in to comment.