Skip to content

Commit

Permalink
Fix for dangling ports (lava-nc#274)
Browse files Browse the repository at this point in the history
* update refport unittest to always wait when it writes to port for consistent behavior

Signed-off-by: bamsumit <[email protected]>

* Removed pyproject changes

Signed-off-by: bamsumit <[email protected]>

* Fix to convolution tests. Fixed imcompatible mnist_pretrained for old python versions.

Signed-off-by: bamsumit <[email protected]>

* Dangling port fix update

Signed-off-by: bamsumit <[email protected]>

* Fix dangling port without breaking channel map unittests

Signed-off-by: bamsumit <[email protected]>

* Remove unnecessary hack for dangling port in tutorial06

Signed-off-by: bamsumit <[email protected]>

* Added dangling port unittests

Signed-off-by: bamsumit <[email protected]>

* Linting fix

Signed-off-by: bamsumit <[email protected]>

* Typo fixes

Signed-off-by: bamsumit <[email protected]>
  • Loading branch information
bamsumit committed Jul 22, 2022
1 parent 8b38336 commit 0658454
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 11 deletions.
10 changes: 8 additions & 2 deletions src/lava/magma/compiler/subcompilers/channel_map_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,21 @@ def add_src_ports(self, src_ports: ty.List[AbstractSrcPort]) -> None:

def add_src_port(self, src_port: AbstractSrcPort) -> None:
for dst_port in src_port.get_dst_ports():
self.add_port_pair(src_port, dst_port)
# If the dst_port is still a source port, then it is
# a dangling branch and need not be processed.
if not isinstance(dst_port, AbstractSrcPort):
self.add_port_pair(src_port, dst_port)

def add_dst_ports(self, dst_ports: ty.List[AbstractDstPort]) -> None:
for dst_port in dst_ports:
self.add_dst_port(dst_port)

def add_dst_port(self, dst_port: AbstractDstPort) -> None:
for src_port in dst_port.get_src_ports():
self.add_port_pair(src_port, dst_port)
# If the src_port is still a destination port, then it is
# a dangling branch and need not be processed.
if not isinstance(src_port, AbstractDstPort):
self.add_port_pair(src_port, dst_port)

def add_port_pair(self,
src_port: AbstractSrcPort,
Expand Down
24 changes: 24 additions & 0 deletions tests/lava/magma/runtime/test_io_ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
class HP1(AbstractProcess):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.h_var = Var(shape=(2,))
self.h_out = OutPort(shape=(2,))


Expand All @@ -38,6 +39,7 @@ def __init__(self, **kwargs):
class P1(AbstractProcess):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.var = Var(shape=(2,), init=4)
self.out = OutPort(shape=(2,))


Expand All @@ -60,6 +62,8 @@ def __init__(self, proc):
# the nested process
self.p1 = P1()
self.p1.out.connect(proc.out_ports.h_out)
# Reference h_var with var of the nested process
proc.vars.h_var.alias(self.p1.var)


# A minimal hierarchical PyProcModel implementing HP2
Expand All @@ -82,11 +86,13 @@ def __init__(self, proc):
@requires(CPU)
@tag('floating_pt')
class PyProcModel1(PyLoihiProcessModel):
var: np.ndarray = LavaPyType(np.ndarray, np.int32)
out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)

def run_spk(self):
# Send data
data = np.array([1, 2])
self.var = data + 2
self.out.send(data)


Expand Down Expand Up @@ -252,6 +258,24 @@ def test_branching_hierarchical(self):
self.assertTrue(np.all(recv3.h_var.get() == np.array([1, 2])))
sender.stop()

def test_dangling_input(self):
"""Checks if the hierarchical process works with dangling input i.e.
input not connected at all."""
receiver = HP2()
receiver.run(condition=RunSteps(num_steps=2),
run_cfg=Loihi1SimCfg(select_sub_proc_model=True))
self.assertTrue(np.all(receiver.h_var.get() == np.array([0, 0])))
receiver.stop()

def test_dangling_output(self):
"""Checks if the hierarchical process works with dangling output i.e.
output not connected at all."""
sender = HP1()
sender.run(condition=RunSteps(num_steps=2),
run_cfg=Loihi1SimCfg(select_sub_proc_model=True))
self.assertTrue(np.all(sender.h_var.get() == np.array([3, 4])))
sender.stop()


if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions tutorials/end_to_end/mnist_pretrained.npy
Git LFS file not shown
7 changes: 0 additions & 7 deletions tutorials/in_depth/tutorial06_hierarchical_processes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,6 @@
"layer1 = DenseLayer(shape=dim, weights=weights1, bias_mant=4, vth=10)\n",
"# Connect the first DenseLayer to the second DenseLayer.\n",
"layer0.s_out.connect(layer1.s_in)\n",
"# Instantiate 'plugs' for the dangling InPort and OutPort of the DenseLayers.\n",
"# (This is a work around and will be solved more elegantly later.)\n",
"inport_plug = source.RingBuffer(data=np.zeros((dim[0], 1)))\n",
"outport_plug = sink.RingBuffer(shape=(dim[0],), buffer=1)\n",
"# Connect the 'plugs' to their respective ports.\n",
"inport_plug.s_out.connect(layer0.s_in)\n",
"layer1.s_out.connect(outport_plug.a_in)\n",
"\n",
"print('Layer 1 weights: \\n', layer1.weights.get(),'\\n')\n",
"print('\\n ----- \\n')\n",
Expand Down

0 comments on commit 0658454

Please sign in to comment.