Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for dangling ports #274

Merged
merged 18 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/lava/magma/compiler/subcompilers/channel_map_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,21 @@ def add_src_ports(self, src_ports: ty.List[AbstractSrcPort]) -> None:

def add_src_port(self, src_port: AbstractSrcPort) -> None:
for dst_port in src_port.get_dst_ports():
self.add_port_pair(src_port, dst_port)
# If the dst_port is still a source port, then it is
# a dangling branch and need not be processed.
if not isinstance(dst_port, AbstractSrcPort):
self.add_port_pair(src_port, dst_port)

def add_dst_ports(self, dst_ports: ty.List[AbstractDstPort]) -> None:
for dst_port in dst_ports:
self.add_dst_port(dst_port)

def add_dst_port(self, dst_port: AbstractDstPort) -> None:
for src_port in dst_port.get_src_ports():
self.add_port_pair(src_port, dst_port)
# If the src_port is still a destination port, then it is
# a dangling branch and need not be processed.
if not isinstance(src_port, AbstractDstPort):
self.add_port_pair(src_port, dst_port)

def add_port_pair(self,
src_port: AbstractSrcPort,
Expand Down
24 changes: 24 additions & 0 deletions tests/lava/magma/runtime/test_io_ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
class HP1(AbstractProcess):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.h_var = Var(shape=(2,))
self.h_out = OutPort(shape=(2,))


Expand All @@ -38,6 +39,7 @@ def __init__(self, **kwargs):
class P1(AbstractProcess):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.var = Var(shape=(2,), init=4)
self.out = OutPort(shape=(2,))


Expand All @@ -60,6 +62,8 @@ def __init__(self, proc):
# the nested process
self.p1 = P1()
self.p1.out.connect(proc.out_ports.h_out)
# Reference h_var with var of the nested process
bamsumit marked this conversation as resolved.
Show resolved Hide resolved
proc.vars.h_var.alias(self.p1.var)


# A minimal hierarchical PyProcModel implementing HP2
Expand All @@ -82,11 +86,13 @@ def __init__(self, proc):
@requires(CPU)
@tag('floating_pt')
class PyProcModel1(PyLoihiProcessModel):
var: np.ndarray = LavaPyType(np.ndarray, np.int32)
out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)

def run_spk(self):
# Send data
data = np.array([1, 2])
self.var = data + 2
self.out.send(data)


Expand Down Expand Up @@ -252,6 +258,24 @@ def test_branching_hierarchical(self):
self.assertTrue(np.all(recv3.h_var.get() == np.array([1, 2])))
sender.stop()

def test_dangling_input(self):
"""Checks if the hierarchical process works with dangling input i.e.
input not connected at all."""
receiver = HP2()
receiver.run(condition=RunSteps(num_steps=2),
run_cfg=Loihi1SimCfg(select_sub_proc_model=True))
self.assertTrue(np.all(receiver.h_var.get() == np.array([0, 0])))
receiver.stop()

def test_dangling_output(self):
"""Checks if the hierarchical process works with dangling output i.e.
output not conected at all."""
bamsumit marked this conversation as resolved.
Show resolved Hide resolved
sender = HP1()
sender.run(condition=RunSteps(num_steps=2),
run_cfg=Loihi1SimCfg(select_sub_proc_model=True))
self.assertTrue(np.all(sender.h_var.get() == np.array([3, 4])))
sender.stop()


if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions tutorials/end_to_end/mnist_pretrained.npy
Git LFS file not shown
7 changes: 0 additions & 7 deletions tutorials/in_depth/tutorial06_hierarchical_processes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,6 @@
"layer1 = DenseLayer(shape=dim, weights=weights1, bias_mant=4, vth=10)\n",
"# Connect the first DenseLayer to the second DenseLayer.\n",
"layer0.s_out.connect(layer1.s_in)\n",
"# Instantiate 'plugs' for the dangling InPort and OutPort of the DenseLayers.\n",
"# (This is a work around and will be solved more elegantly later.)\n",
"inport_plug = source.RingBuffer(data=np.zeros((dim[0], 1)))\n",
"outport_plug = sink.RingBuffer(shape=(dim[0],), buffer=1)\n",
"# Connect the 'plugs' to their respective ports.\n",
"inport_plug.s_out.connect(layer0.s_in)\n",
"layer1.s_out.connect(outport_plug.a_in)\n",
"\n",
"print('Layer 1 weights: \\n', layer1.weights.get(),'\\n')\n",
"print('\\n ----- \\n')\n",
Expand Down