Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable RefPort connections #24

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(self, compile_cfg: ty.Optional[ty.Dict[str, ty.Any]] = None):
self._compile_config.update(compile_cfg)

# ToDo: (AW) Clean this up by avoiding redundant search paths
# ToDo: (AW) @PP Please include RefPorts/VarPorts in connection tracing
def _find_processes(self,
proc: AbstractProcess,
seen_procs: ty.List[AbstractProcess] = None) \
Expand All @@ -70,14 +69,14 @@ def _find_processes(self,
new_list: ty.List[AbstractProcess] = []

# add processes connecting to the main process
for in_port in proc.in_ports:
for in_port in proc.in_ports.members + proc.var_ports.members:
for con in in_port.in_connections:
new_list.append(con.process)
for con in in_port.out_connections:
new_list.append(con.process)

# add processes connecting from the main process
for out_port in proc.out_ports:
for out_port in proc.out_ports.members + proc.ref_ports.members:
for con in out_port.in_connections:
new_list.append(con.process)
for con in out_port.out_connections:
Expand Down Expand Up @@ -273,7 +272,7 @@ def _compile_proc_models(
v = [VarInitializer(v.name, v.shape, v.init, v.id)
for v in p.vars]
ports = (list(p.in_ports) + list(p.out_ports)
+ list(p.ref_ports))
+ list(p.ref_ports) + list(p.var_ports))
ports = [PortInitializer(pt.name,
pt.shape,
getattr(pm, pt.name).d_type,
Expand Down
65 changes: 63 additions & 2 deletions tests/magma/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from lava.magma.core.sync.domain import SyncDomain
from lava.magma.core.sync.protocol import AbstractSyncProtocol
from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol
from lava.magma.core.process.ports.ports import InPort, OutPort
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.process.ports.ports import (
InPort, OutPort, RefPort, VarPort)
from lava.magma.core.model.py.ports import PyInPort, PyOutPort, PyRefPort
from lava.magma.core.run_configs import RunConfig
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.process.variable import Var, VarServer
Expand All @@ -29,6 +30,7 @@ def __init__(self, **kwargs):
# Use ReduceOp to allow for multiple input connections
self.inp = InPort(shape=(1,), reduce_op=ReduceSum)
self.out = OutPort(shape=(1,))
self.ref = RefPort(shape=(10,))


# Another minimal process (does not matter that it's identical to ProcA)
Expand All @@ -39,6 +41,7 @@ def __init__(self, **kwargs):
self.inp = InPort(shape=(1,), reduce_op=ReduceSum)
self.out = OutPort(shape=(1,))
self.some_var = Var((10,), init=10)
self.var_port = VarPort(self.some_var)


# Another minimal process (does not matter that it's identical to ProcA)
Expand Down Expand Up @@ -74,6 +77,7 @@ def runtime_service(self):
class PyProcModelA(AbstractPyProcessModel):
inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int)
out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)
ref: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int)

def run(self):
pass
Expand All @@ -86,6 +90,7 @@ class PyProcModelB(AbstractPyProcessModel):
inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int)
out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)
some_var: int = LavaPyType(int, int)
var_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int)

def run(self):
pass
Expand Down Expand Up @@ -262,6 +267,32 @@ def test_find_process_circular(self):
self.assertEqual(set(procs4), all_procs)
self.assertEqual(set(procs6), all_procs)

def test_find_process_ref_ports(self):
"""Checks finding all processes for RefPort connection.
[p1 -> ref/var -> p2 -> in/out -> p3]"""

# Create processes
p1, p2, p3 = (
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
ProcA(),
ProcB(),
ProcC()
)
# Create complicated circular structure with joins and forks
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
p1.ref.connect(p2.var_port)
p2.out.connect(p3.inp)

# Regardless where we start searching...
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
c = Compiler()
procs1 = c._find_processes(p1)
procs2 = c._find_processes(p2)
procs3 = c._find_processes(p3)

# ...we will find all of them
all_procs = {p1, p2, p3}
self.assertEqual(set(procs1), all_procs)
self.assertEqual(set(procs2), all_procs)
self.assertEqual(set(procs3), all_procs)

def test_find_proc_models(self):
"""Check finding of ProcModels that implement a Process."""

Expand Down Expand Up @@ -658,6 +689,36 @@ def test_create_channel_builders_hierarchical_process(self):
self.assertEqual(chb[0].src_process, p.procs.proc1)
self.assertEqual(chb[0].dst_process, p.procs.proc2)

def test_create_channel_builders_ref_ports(self):
"""Checks creation of channel builders when a process is connected
using a RefPort to another process."""

# create a process with a RefPort (source)
src = ProcA()

# create a process with a var (destination)
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
dst = ProcB()

# connect them using RefPort and VarPort
src.ref.connect(dst.var_port)

# Create a manual proc_map
proc_map = {
src: PyProcModelA,
dst: PyProcModelB
}

# Create channel builders
c = Compiler()
cbs = c._create_channel_builders(proc_map)

# This should result in 1 channel builder
from lava.magma.compiler.builder import ChannelBuilderMp
self.assertEqual(len(cbs), 1)
self.assertIsInstance(cbs[0], ChannelBuilderMp)
self.assertEqual(cbs[0].src_process, src)
self.assertEqual(cbs[0].dst_process, dst)

# ToDo: (AW) @YS/@JM Please fix unit test by passing run_srv_builders to
# _create_exec_vars when ready
def test_create_py_exec_vars(self):
Expand Down