diff --git a/lava/magma/compiler/compiler.py b/lava/magma/compiler/compiler.py index 1da5e178c..0f47dedb7 100644 --- a/lava/magma/compiler/compiler.py +++ b/lava/magma/compiler/compiler.py @@ -32,6 +32,7 @@ from lava.magma.core.model.nc.model import AbstractNcProcessModel from lava.magma.core.model.py.model import AbstractPyProcessModel from lava.magma.core.model.sub.model import AbstractSubProcessModel +from lava.magma.core.process.ports.ports import AbstractPort, VarPort from lava.magma.core.process.process import AbstractProcess from lava.magma.core.resources import CPU, NeuroCore from lava.magma.core.run_configs import RunConfig @@ -50,7 +51,6 @@ def __init__(self, compile_cfg: ty.Optional[ty.Dict[str, ty.Any]] = None): self._compile_config.update(compile_cfg) # ToDo: (AW) Clean this up by avoiding redundant search paths - # ToDo: (AW) @PP Please include RefPorts/VarPorts in connection tracing def _find_processes(self, proc: AbstractProcess, seen_procs: ty.List[AbstractProcess] = None) \ @@ -70,14 +70,14 @@ def _find_processes(self, new_list: ty.List[AbstractProcess] = [] # add processes connecting to the main process - for in_port in proc.in_ports: + for in_port in proc.in_ports.members + proc.var_ports.members: for con in in_port.in_connections: new_list.append(con.process) for con in in_port.out_connections: new_list.append(con.process) # add processes connecting from the main process - for out_port in proc.out_ports: + for out_port in proc.out_ports.members + proc.ref_ports.members: for con in out_port.in_connections: new_list.append(con.process) for con in out_port.out_connections: @@ -273,10 +273,10 @@ def _compile_proc_models( v = [VarInitializer(v.name, v.shape, v.init, v.id) for v in p.vars] ports = (list(p.in_ports) + list(p.out_ports) - + list(p.ref_ports)) + + list(p.ref_ports) + list(p.var_ports)) ports = [PortInitializer(pt.name, pt.shape, - getattr(pm, pt.name).d_type, + self._get_port_dtype(pt, pm), pt.__class__.__name__, pp_ch_size) for pt in ports] # Assigns initializers to builder @@ -497,6 +497,23 @@ def _get_channel_type(src: ty.Type[AbstractProcessModel], f"'({src.__name__}, {dst.__name__})' yet." ) + @staticmethod + def _get_port_dtype(port: AbstractPort, + proc_model: ty.Type[AbstractProcessModel]) -> type: + """Returns the type of a port, as specified in the corresponding + ProcessModel.""" + + # In-, Out-, Ref- and explicit VarPorts + if hasattr(proc_model, port.name): + return getattr(proc_model, port.name).d_type + # Implicitly created VarPorts + elif isinstance(port, VarPort): + return getattr(proc_model, port.var.name).d_type + # Port has different name in Process and ProcessModel + else: + raise AssertionError("Port {!r} not found in " + "ProcessModel {!r}".format(port, proc_model)) + # ToDo: (AW) Fix hard-coded hacks in this method and extend to other # channel types def _create_channel_builders(self, proc_map: PROC_MAP) \ @@ -526,7 +543,7 @@ def _create_channel_builders(self, proc_map: PROC_MAP) \ # Find destination ports for each source port for src_pt in src_ports: # Create PortInitializer for source port - src_pt_dtype = getattr(src_pm, src_pt.name).d_type + src_pt_dtype = self._get_port_dtype(src_pt, src_pm) src_pt_init = PortInitializer( src_pt.name, src_pt.shape, src_pt_dtype, src_pt.__class__.__name__, ch_size) @@ -541,7 +558,7 @@ def _create_channel_builders(self, proc_map: PROC_MAP) \ # Find appropriate channel type ch_type = self._get_channel_type(src_pm, dst_pm) # Create PortInitializer for destination port - dst_pt_d_type = getattr(dst_pm, dst_pt.name).d_type + dst_pt_d_type = self._get_port_dtype(dst_pt, dst_pm) dst_pt_init = PortInitializer( dst_pt.name, dst_pt.shape, dst_pt_d_type, dst_pt.__class__.__name__, ch_size) diff --git a/lava/magma/compiler/utils.py b/lava/magma/compiler/utils.py index fbf76d07b..52a636670 100644 --- a/lava/magma/compiler/utils.py +++ b/lava/magma/compiler/utils.py @@ -1,8 +1,6 @@ import typing as ty from dataclasses import dataclass -import numpy as np - @dataclass class VarInitializer: @@ -16,6 +14,6 @@ class VarInitializer: class PortInitializer: name: str shape: ty.Tuple[int, ...] - d_type: ty.Type[np.intc] + d_type: type port_type: str size: int diff --git a/lava/magma/core/process/ports/ports.py b/lava/magma/core/process/ports/ports.py index 32f609e36..4913fd86e 100644 --- a/lava/magma/core/process/ports/ports.py +++ b/lava/magma/core/process/ports/ports.py @@ -391,10 +391,15 @@ def connect_var(self, variables: ty.Union[Var, ty.List[Var]]): # Create a VarPort to wrap Var vp = VarPort(v) # Propagate name and parent process of Var to VarPort - vp.name = v.name + "_port" + vp.name = "_" + v.name + "_implicit_port" if v.process is not None: # Only assign when parent process is already assigned vp.process = v.process + # VarPort Name could shadow existing attribute + if hasattr(v.process, vp.name): + raise AssertionError( + "Name of implicit VarPort might conflict" + " with existing attribute.") var_ports.append(vp) # Connect RefPort to VarPorts that wrap Vars self.connect(var_ports) diff --git a/tests/magma/compiler/test_compiler.py b/tests/magma/compiler/test_compiler.py index b8409015f..8ce317624 100644 --- a/tests/magma/compiler/test_compiler.py +++ b/tests/magma/compiler/test_compiler.py @@ -13,8 +13,9 @@ from lava.magma.core.sync.domain import SyncDomain from lava.magma.core.sync.protocol import AbstractSyncProtocol from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol -from lava.magma.core.process.ports.ports import InPort, OutPort -from lava.magma.core.model.py.ports import PyInPort, PyOutPort +from lava.magma.core.process.ports.ports import ( + InPort, OutPort, RefPort, VarPort) +from lava.magma.core.model.py.ports import PyInPort, PyOutPort, PyRefPort from lava.magma.core.run_configs import RunConfig from lava.magma.core.model.py.type import LavaPyType from lava.magma.core.process.variable import Var, VarServer @@ -29,6 +30,7 @@ def __init__(self, **kwargs): # Use ReduceOp to allow for multiple input connections self.inp = InPort(shape=(1,), reduce_op=ReduceSum) self.out = OutPort(shape=(1,)) + self.ref = RefPort(shape=(10,)) # Another minimal process (does not matter that it's identical to ProcA) @@ -39,6 +41,7 @@ def __init__(self, **kwargs): self.inp = InPort(shape=(1,), reduce_op=ReduceSum) self.out = OutPort(shape=(1,)) self.some_var = Var((10,), init=10) + self.var_port = VarPort(self.some_var) # Another minimal process (does not matter that it's identical to ProcA) @@ -74,6 +77,7 @@ def runtime_service(self): class PyProcModelA(AbstractPyProcessModel): inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int) out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int) + ref: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int) def run(self): pass @@ -86,6 +90,7 @@ class PyProcModelB(AbstractPyProcessModel): inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int) out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int) some_var: int = LavaPyType(int, int) + var_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int) def run(self): pass @@ -262,6 +267,30 @@ def test_find_process_circular(self): self.assertEqual(set(procs4), all_procs) self.assertEqual(set(procs6), all_procs) + def test_find_process_ref_ports(self): + """Checks finding all processes for RefPort connection. + [p1 -> ref/var -> p2 -> in/out -> p3]""" + + # Create processes + p1, p2, p3 = ProcA(), ProcB(), ProcC() + + # Connect p1 (RefPort) with p2 (VarPort) + p1.ref.connect(p2.var_port) + # Connect p2 (OutPort) with p3 (InPort) + p2.out.connect(p3.inp) + + # Regardless where we start searching... + c = Compiler() + procs1 = c._find_processes(p1) + procs2 = c._find_processes(p2) + procs3 = c._find_processes(p3) + + # ...we will find all of them + all_procs = {p1, p2, p3} + self.assertEqual(set(procs1), all_procs) + self.assertEqual(set(procs2), all_procs) + self.assertEqual(set(procs3), all_procs) + def test_find_proc_models(self): """Check finding of ProcModels that implement a Process.""" @@ -658,6 +687,66 @@ def test_create_channel_builders_hierarchical_process(self): self.assertEqual(chb[0].src_process, p.procs.proc1) self.assertEqual(chb[0].dst_process, p.procs.proc2) + def test_create_channel_builders_ref_ports(self): + """Checks creation of channel builders when a process is connected + using a RefPort to another process VarPort.""" + + # Create a process with a RefPort (source) + src = ProcA() + + # Create a process with a var (destination) + dst = ProcB() + + # Connect them using RefPort and VarPort + src.ref.connect(dst.var_port) + + # Create a manual proc_map + proc_map = { + src: PyProcModelA, + dst: PyProcModelB + } + + # Create channel builders + c = Compiler() + cbs = c._create_channel_builders(proc_map) + + # This should result in 1 channel builder + from lava.magma.compiler.builder import ChannelBuilderMp + self.assertEqual(len(cbs), 1) + self.assertIsInstance(cbs[0], ChannelBuilderMp) + self.assertEqual(cbs[0].src_process, src) + self.assertEqual(cbs[0].dst_process, dst) + + def test_create_channel_builders_ref_ports_implicit(self): + """Checks creation of channel builders when a process is connected + using a RefPort to another process Var (implicit VarPort).""" + + # Create a process with a RefPort (source) + src = ProcA() + + # Create a process with a var (destination) + dst = ProcB() + + # Connect them using RefPort and Var (creates implicitly a VarPort) + src.ref.connect_var(dst.some_var) + + # Create a manual proc_map + proc_map = { + src: PyProcModelA, + dst: PyProcModelB + } + + # Create channel builders + c = Compiler() + cbs = c._create_channel_builders(proc_map) + + # This should result in 1 channel builder + from lava.magma.compiler.builder import ChannelBuilderMp + self.assertEqual(len(cbs), 1) + self.assertIsInstance(cbs[0], ChannelBuilderMp) + self.assertEqual(cbs[0].src_process, src) + self.assertEqual(cbs[0].dst_process, dst) + # ToDo: (AW) @YS/@JM Please fix unit test by passing run_srv_builders to # _create_exec_vars when ready def test_create_py_exec_vars(self): diff --git a/tests/magma/core/process/test_ports.py b/tests/magma/core/process/test_ports.py index a395b22a8..7c611f25a 100644 --- a/tests/magma/core/process/test_ports.py +++ b/tests/magma/core/process/test_ports.py @@ -1,7 +1,7 @@ # Copyright (C) 2021 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause import unittest - +from lava.magma.core.process.process import AbstractProcess from lava.magma.core.process.ports.ports import ( InPort, OutPort, @@ -245,9 +245,68 @@ def test_connect_RefPort_to_Var(self): # In this case, the VarPort inherits its name and parent process from # the Var it wraps - self.assertEqual(vp.name, v.name + "_port") + self.assertEqual(vp.name, "_" + v.name + "_implicit_port") # (We can't check for the same parent process here because it has not - # been assigned ot the Var yet) + # been assigned to the Var yet) + + def test_connect_RefPort_to_Var_process(self): + """Checks connecting RefPort implicitly to Var, with registered + processes.""" + + # Create a mock parent process + class VarProcess(AbstractProcess): + ... + + # Create a Var and RefPort... + v = Var((1, 2, 3)) + rp = RefPort((1, 2, 3)) + + # ...register a process for the Var + v.process = VarProcess() + + # ...then connect them directly via connect_var(..) + rp.connect_var(v) + + # This has the same effect as connecting a RefPort explicitly via a + # VarPort to a Var... + self.assertEqual(rp.get_dst_vars(), [v]) + # ... but still creates a VarPort implicitly + vp = rp.get_dst_ports()[0] + self.assertIsInstance(vp, VarPort) + # ... which wraps the original Var + self.assertEqual(vp.var, v) + + # In this case, the VarPort inherits its name and parent process from + # the Var it wraps + self.assertEqual(vp.name, "_" + v.name + "_implicit_port") + self.assertEqual(vp.process, v.process) + + def test_connect_RefPort_to_Var_process_conflict(self): + """Checks connecting RefPort implicitly to Var, with registered + processes and conflicting names. -> AssertionError""" + + # Create a mock parent process + class VarProcess(AbstractProcess): + # Attribute is named like our implicit VarPort after creation + _existing_attr_implicit_port = None + + # Create a Var and RefPort... + v = Var((1, 2, 3)) + rp = RefPort((1, 2, 3)) + + # Create a Var and RefPort... + v = Var((1, 2, 3)) + rp = RefPort((1, 2, 3)) + + # ...register a process for the Var and name it so it conflicts with + # the attribute ov VarProcess (very unlikely to happen) + v.process = VarProcess() + v.name = "existing_attr" + + # ... and connect it directly via connect_var(..) + # The naming conflict should raise an AssertionError + with self.assertRaises(AssertionError): + rp.connect_var(v) def test_connect_RefPort_to_many_Vars(self): """Checks that RefPort can be connected to many Vars."""