Skip to content

Commit

Permalink
Virtual ports no longer block Process discovery in compiler (#211)
Browse files Browse the repository at this point in the history
* Fixing #210, where virtual ports block Process discovery in the compiler.
* Fixing possible race condition in unit test

Signed-off-by: Mathis Richter <[email protected]>
  • Loading branch information
mathisrichter committed Mar 3, 2022
1 parent 1d1ba2c commit 9622855
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ def _find_processes(self,

# add processes connecting to the main process
for in_port in proc.in_ports.members + proc.var_ports.members:
for con in in_port.in_connections:
for con in in_port.get_src_ports():
new_list.append(con.process)
for con in in_port.out_connections:
for con in in_port.get_dst_ports():
new_list.append(con.process)

# add processes connecting from the main process
for out_port in proc.out_ports.members + proc.ref_ports.members:
for con in out_port.in_connections:
for con in out_port.get_src_ports():
new_list.append(con.process)
for con in out_port.out_connections:
for con in out_port.get_dst_ports():
new_list.append(con.process)

for proc in set(new_list):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import functools as ft

from lava.magma.compiler.compiler import Compiler
from lava.magma.core.decorator import requires, tag, implements
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.sub.model import AbstractSubProcessModel
Expand Down Expand Up @@ -297,9 +298,40 @@ def test_varport_to_varport_read_in_a_hierarchical_process(self) -> None:
f'{expected[output!=expected] =}\n'
)

def test_compiler_finds_all_processes(self) -> None:
"""Tests whether in Process graphs with virtual ports, all Processes
are found, no matter from which Process the search is started."""

source = OutPortProcess(data=self.input_data)
sink = InPortProcess(shape=self.shape)

virtual_port1 = MockVirtualPort(new_shape=self.new_shape,
axes=self.axes)
virtual_port2 = MockVirtualPort(new_shape=self.shape,
axes=tuple(np.argsort(self.axes)))

source.out_port._connect_forward(
[virtual_port1], AbstractPort, assert_same_shape=False
)
virtual_port1._connect_forward(
[virtual_port2], AbstractPort, assert_same_shape=False
)
virtual_port2.connect(sink.in_port)

compiler = Compiler()
# Test whether all Processes are found when starting the search from
# the source Process
found_procs = compiler._find_processes(source)
expected_procs = [sink, source]
self.assertCountEqual(found_procs, expected_procs)

# Test whether all Processes are found when starting the search from
# the destination Process
found_procs = compiler._find_processes(sink)
self.assertCountEqual(found_procs, expected_procs)

def test_chaining_multiple_virtual_ports(self) -> None:
"""Tests whether two virtual ReshapePorts can be chained through the
flatten() method."""
"""Tests whether virtual ports can be chained."""

source = OutPortProcess(data=self.input_data)
sink = InPortProcess(shape=self.shape)
Expand Down

0 comments on commit 9622855

Please sign in to comment.