diff --git a/src/lava/magma/compiler/compiler.py b/src/lava/magma/compiler/compiler.py index 7e3ac47ef..7152d4b8d 100644 --- a/src/lava/magma/compiler/compiler.py +++ b/src/lava/magma/compiler/compiler.py @@ -87,16 +87,16 @@ def _find_processes(self, # add processes connecting to the main process for in_port in proc.in_ports.members + proc.var_ports.members: - for con in in_port.in_connections: + for con in in_port.get_src_ports(): new_list.append(con.process) - for con in in_port.out_connections: + for con in in_port.get_dst_ports(): new_list.append(con.process) # add processes connecting from the main process for out_port in proc.out_ports.members + proc.ref_ports.members: - for con in out_port.in_connections: + for con in out_port.get_src_ports(): new_list.append(con.process) - for con in out_port.out_connections: + for con in out_port.get_dst_ports(): new_list.append(con.process) for proc in set(new_list): diff --git a/tests/lava/magma/core/process/ports/test_virtual_ports_in_process.py b/tests/lava/magma/core/process/ports/test_virtual_ports_in_process.py index deb9cef91..da5c36a09 100644 --- a/tests/lava/magma/core/process/ports/test_virtual_ports_in_process.py +++ b/tests/lava/magma/core/process/ports/test_virtual_ports_in_process.py @@ -7,6 +7,7 @@ import numpy as np import functools as ft +from lava.magma.compiler.compiler import Compiler from lava.magma.core.decorator import requires, tag, implements from lava.magma.core.model.py.model import PyLoihiProcessModel from lava.magma.core.model.sub.model import AbstractSubProcessModel @@ -297,9 +298,40 @@ def test_varport_to_varport_read_in_a_hierarchical_process(self) -> None: f'{expected[output!=expected] =}\n' ) + def test_compiler_finds_all_processes(self) -> None: + """Tests whether in Process graphs with virtual ports, all Processes + are found, no matter from which Process the search is started.""" + + source = OutPortProcess(data=self.input_data) + sink = InPortProcess(shape=self.shape) + + virtual_port1 = MockVirtualPort(new_shape=self.new_shape, + axes=self.axes) + virtual_port2 = MockVirtualPort(new_shape=self.shape, + axes=tuple(np.argsort(self.axes))) + + source.out_port._connect_forward( + [virtual_port1], AbstractPort, assert_same_shape=False + ) + virtual_port1._connect_forward( + [virtual_port2], AbstractPort, assert_same_shape=False + ) + virtual_port2.connect(sink.in_port) + + compiler = Compiler() + # Test whether all Processes are found when starting the search from + # the source Process + found_procs = compiler._find_processes(source) + expected_procs = [sink, source] + self.assertCountEqual(found_procs, expected_procs) + + # Test whether all Processes are found when starting the search from + # the destination Process + found_procs = compiler._find_processes(sink) + self.assertCountEqual(found_procs, expected_procs) + def test_chaining_multiple_virtual_ports(self) -> None: - """Tests whether two virtual ReshapePorts can be chained through the - flatten() method.""" + """Tests whether virtual ports can be chained.""" source = OutPortProcess(data=self.input_data) sink = InPortProcess(shape=self.shape)