diff --git a/src/lava/magma/compiler/subcompilers/channel_map_updater.py b/src/lava/magma/compiler/subcompilers/channel_map_updater.py index f87666d41..64c6756ed 100644 --- a/src/lava/magma/compiler/subcompilers/channel_map_updater.py +++ b/src/lava/magma/compiler/subcompilers/channel_map_updater.py @@ -36,7 +36,10 @@ def add_src_ports(self, src_ports: ty.List[AbstractSrcPort]) -> None: def add_src_port(self, src_port: AbstractSrcPort) -> None: for dst_port in src_port.get_dst_ports(): - self.add_port_pair(src_port, dst_port) + # If the dst_port is still a source port, then it is + # a dangling branch and need not be processed. + if not isinstance(dst_port, AbstractSrcPort): + self.add_port_pair(src_port, dst_port) def add_dst_ports(self, dst_ports: ty.List[AbstractDstPort]) -> None: for dst_port in dst_ports: @@ -44,7 +47,10 @@ def add_dst_ports(self, dst_ports: ty.List[AbstractDstPort]) -> None: def add_dst_port(self, dst_port: AbstractDstPort) -> None: for src_port in dst_port.get_src_ports(): - self.add_port_pair(src_port, dst_port) + # If the src_port is still a destination port, then it is + # a dangling branch and need not be processed. + if not isinstance(src_port, AbstractDstPort): + self.add_port_pair(src_port, dst_port) def add_port_pair(self, src_port: AbstractSrcPort, diff --git a/tests/lava/magma/runtime/test_io_ports.py b/tests/lava/magma/runtime/test_io_ports.py index c7c813b0e..7772671d3 100644 --- a/tests/lava/magma/runtime/test_io_ports.py +++ b/tests/lava/magma/runtime/test_io_ports.py @@ -23,6 +23,7 @@ class HP1(AbstractProcess): def __init__(self, **kwargs): super().__init__(**kwargs) + self.h_var = Var(shape=(2,)) self.h_out = OutPort(shape=(2,)) @@ -38,6 +39,7 @@ def __init__(self, **kwargs): class P1(AbstractProcess): def __init__(self, **kwargs): super().__init__(**kwargs) + self.var = Var(shape=(2,), init=4) self.out = OutPort(shape=(2,)) @@ -60,6 +62,8 @@ def __init__(self, proc): # the nested process self.p1 = P1() self.p1.out.connect(proc.out_ports.h_out) + # Reference h_var with var of the nested process + proc.vars.h_var.alias(self.p1.var) # A minimal hierarchical PyProcModel implementing HP2 @@ -82,11 +86,13 @@ def __init__(self, proc): @requires(CPU) @tag('floating_pt') class PyProcModel1(PyLoihiProcessModel): + var: np.ndarray = LavaPyType(np.ndarray, np.int32) out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int) def run_spk(self): # Send data data = np.array([1, 2]) + self.var = data + 2 self.out.send(data) @@ -252,6 +258,24 @@ def test_branching_hierarchical(self): self.assertTrue(np.all(recv3.h_var.get() == np.array([1, 2]))) sender.stop() + def test_dangling_input(self): + """Checks if the hierarchical process works with dangling input i.e. + input not connected at all.""" + receiver = HP2() + receiver.run(condition=RunSteps(num_steps=2), + run_cfg=Loihi1SimCfg(select_sub_proc_model=True)) + self.assertTrue(np.all(receiver.h_var.get() == np.array([0, 0]))) + receiver.stop() + + def test_dangling_output(self): + """Checks if the hierarchical process works with dangling output i.e. + output not connected at all.""" + sender = HP1() + sender.run(condition=RunSteps(num_steps=2), + run_cfg=Loihi1SimCfg(select_sub_proc_model=True)) + self.assertTrue(np.all(sender.h_var.get() == np.array([3, 4]))) + sender.stop() + if __name__ == '__main__': unittest.main() diff --git a/tutorials/end_to_end/mnist_pretrained.npy b/tutorials/end_to_end/mnist_pretrained.npy index 97d4d8e42..99ec90bec 100644 --- a/tutorials/end_to_end/mnist_pretrained.npy +++ b/tutorials/end_to_end/mnist_pretrained.npy @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a55246798c24fb8122ef7be5bd3b8fbb1222a121d9db502d382c3d5590203555 -size 220771 +oid sha256:94f32a3ae7f8dd278cc8933b214642f246ffd859a129d19130ac88208f35c9d6 +size 220767 diff --git a/tutorials/in_depth/tutorial06_hierarchical_processes.ipynb b/tutorials/in_depth/tutorial06_hierarchical_processes.ipynb index 967cf0825..575c027dd 100644 --- a/tutorials/in_depth/tutorial06_hierarchical_processes.ipynb +++ b/tutorials/in_depth/tutorial06_hierarchical_processes.ipynb @@ -461,13 +461,6 @@ "layer1 = DenseLayer(shape=dim, weights=weights1, bias_mant=4, vth=10)\n", "# Connect the first DenseLayer to the second DenseLayer.\n", "layer0.s_out.connect(layer1.s_in)\n", - "# Instantiate 'plugs' for the dangling InPort and OutPort of the DenseLayers.\n", - "# (This is a work around and will be solved more elegantly later.)\n", - "inport_plug = source.RingBuffer(data=np.zeros((dim[0], 1)))\n", - "outport_plug = sink.RingBuffer(shape=(dim[0],), buffer=1)\n", - "# Connect the 'plugs' to their respective ports.\n", - "inport_plug.s_out.connect(layer0.s_in)\n", - "layer1.s_out.connect(outport_plug.a_in)\n", "\n", "print('Layer 1 weights: \\n', layer1.weights.get(),'\\n')\n", "print('\\n ----- \\n')\n",