diff --git a/src/lava/magma/compiler/builder.py b/src/lava/magma/compiler/builder.py index db74adda6..11fed95d5 100644 --- a/src/lava/magma/compiler/builder.py +++ b/src/lava/magma/compiler/builder.py @@ -19,13 +19,17 @@ import numpy as np from dataclasses import dataclass - from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort from lava.magma.core.model.py.model import AbstractPyProcessModel from lava.magma.core.model.py.type import LavaPyType -from lava.magma.compiler.utils import VarInitializer, PortInitializer -from lava.magma.core.model.py.ports import AbstractPyPort, \ - PyInPort, PyOutPort, PyRefPort +from lava.magma.compiler.utils import VarInitializer, PortInitializer, \ + VarPortInitializer +from lava.magma.core.model.py.ports import ( + AbstractPyPort, + PyInPort, + PyOutPort, + PyRefPort, PyVarPort, +) from lava.magma.compiler.channels.interfaces import AbstractCspPort, Channel, \ ChannelType @@ -91,6 +95,8 @@ def __init__( self._model_id = model_id self.vars: ty.Dict[str, VarInitializer] = {} self.py_ports: ty.Dict[str, PortInitializer] = {} + self.ref_ports: ty.Dict[str, PortInitializer] = {} + self.var_ports: ty.Dict[str, VarPortInitializer] = {} self.csp_ports: ty.Dict[str, ty.List[AbstractCspPort]] = {} self.csp_rs_send_port: ty.Dict[str, CspSendPort] = {} self.csp_rs_recv_port: ty.Dict[str, CspRecvPort] = {} @@ -167,6 +173,8 @@ def check_all_vars_and_ports_set(self): if ( attr_name not in self.vars and attr_name not in self.py_ports + and attr_name not in self.ref_ports + and attr_name not in self.var_ports ): raise AssertionError( f"No LavaPyType '{attr_name}' found in ProcModel " @@ -207,7 +215,9 @@ def check_lava_py_types(self): # ToDo: Also check that Vars are initializable with var.value provided def set_variables(self, variables: ty.List[VarInitializer]): - """Set variables list + """Appends the given list of variables to the ProcessModel. Used by the + compiler to create a ProcessBuilder during the compilation of + ProcessModels. Parameters ---------- @@ -220,7 +230,9 @@ def set_variables(self, variables: ty.List[VarInitializer]): self.vars.update(new_vars) def set_py_ports(self, py_ports: ty.List[PortInitializer], check=True): - """Set py_ports + """Appends the given list of PyPorts to the ProcessModel. Used by the + compiler to create a ProcessBuilder during the compilation of + ProcessModels. Parameters ---------- @@ -235,8 +247,36 @@ def set_py_ports(self, py_ports: ty.List[PortInitializer], check=True): self._check_not_assigned_yet(self.py_ports, new_ports.keys(), "ports") self.py_ports.update(new_ports) + def set_ref_ports(self, ref_ports: ty.List[PortInitializer]): + """Appends the given list of RefPorts to the ProcessModel. Used by the + compiler to create a ProcessBuilder during the compilation of + ProcessModels. + + Parameters + ---------- + ref_ports : ty.List[PortInitializer] + """ + self._check_members_exist(ref_ports, "Port") + new_ports = {p.name: p for p in ref_ports} + self._check_not_assigned_yet(self.ref_ports, new_ports.keys(), "ports") + self.ref_ports.update(new_ports) + + def set_var_ports(self, var_ports: ty.List[VarPortInitializer]): + """Appends the given list of VarPorts to the ProcessModel. Used by the + compiler to create a ProcessBuilder during the compilation of + ProcessModels. + + Parameters + ---------- + var_ports : ty.List[VarPortInitializer] + """ + new_ports = {p.name: p for p in var_ports} + self._check_not_assigned_yet(self.var_ports, new_ports.keys(), "ports") + self.var_ports.update(new_ports) + def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]): - """Set CSP Ports + """Appends the given list of CspPorts to the ProcessModel. Used by the + runtime to configure csp ports during initialization (_build_channels). Parameters ---------- @@ -253,22 +293,18 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]): new_ports.setdefault(p.name, []).extend( p if isinstance(p, list) else [p] ) - self._check_not_assigned_yet( - self.csp_ports, new_ports.keys(), "csp_ports" - ) + # Check that there's a PyPort for each new CspPort proc_name = self.proc_model.implements_process.__name__ for port_name in new_ports: if not hasattr(self.proc_model, port_name): - raise AssertionError( - "PyProcessModel '{}' has \ - no port named '{}'.".format( - proc_name, port_name - ) - ) - # Set new CspPorts - for key, ports in new_ports.items(): - self.csp_ports.setdefault(key, []).extend(ports) + raise AssertionError("PyProcessModel '{}' has \ + no port named '{}'.".format(proc_name, port_name)) + + if port_name in self.csp_ports: + self.csp_ports[port_name].extend(new_ports[port_name]) + else: + self.csp_ports[port_name] = new_ports[port_name] def set_rs_csp_ports(self, csp_ports: ty.List[AbstractCspPort]): """Set RS CSP Ports @@ -288,6 +324,9 @@ def _get_lava_type(self, name: str) -> LavaPyType: return getattr(self.proc_model, name) # ToDo: Need to differentiate signed and unsigned variable precisions + # TODO: (PP) Combine PyPort/RefPort/VarPort initialization + # TODO: (PP) Find a cleaner way to find/address csp_send/csp_recv ports (in + # Ref/VarPort initialization) def build(self): """Builds a PyProcModel at runtime within Runtime. @@ -326,13 +365,53 @@ def build(self): csp_ports = self.csp_ports[name] if not isinstance(csp_ports, list): csp_ports = [csp_ports] - port = port_cls(pm, csp_ports, p.shape, lt.d_type) + port = port_cls(csp_ports, pm, p.shape, lt.d_type) # Create dynamic PyPort attribute on ProcModel setattr(pm, name, port) # Create private attribute for port precision # setattr(pm, "_" + name + "_p", lt.precision) + # Initialize RefPorts + for name, p in self.ref_ports.items(): + # Build RefPort + lt = self._get_lava_type(name) + port_cls = ty.cast(ty.Type[PyRefPort], lt.cls) + csp_recv = None + csp_send = None + if name in self.csp_ports: + csp_ports = self.csp_ports[name] + csp_recv = csp_ports[0] if isinstance( + csp_ports[0], CspRecvPort) else csp_ports[1] + csp_send = csp_ports[0] if isinstance( + csp_ports[0], CspSendPort) else csp_ports[1] + + port = port_cls(csp_send, csp_recv, pm, p.shape, lt.d_type) + + # Create dynamic RefPort attribute on ProcModel + setattr(pm, name, port) + + # Initialize VarPorts + for name, p in self.var_ports.items(): + # Build VarPort + if p.port_cls is None: + # VarPort is not connected + continue + port_cls = ty.cast(ty.Type[PyVarPort], p.port_cls) + csp_recv = None + csp_send = None + if name in self.csp_ports: + csp_ports = self.csp_ports[name] + csp_recv = csp_ports[0] if isinstance( + csp_ports[0], CspRecvPort) else csp_ports[1] + csp_send = csp_ports[0] if isinstance( + csp_ports[0], CspSendPort) else csp_ports[1] + port = port_cls( + p.var_name, csp_send, csp_recv, pm, p.shape, p.d_type) + + # Create dynamic VarPort attribute on ProcModel + setattr(pm, name, port) + for port in self.csp_rs_recv_port.values(): if "service_to_process_cmd" in port.name: pm.service_to_process_cmd = port diff --git a/src/lava/magma/compiler/compiler.py b/src/lava/magma/compiler/compiler.py index 2948fa0d9..2ea6a4ae8 100644 --- a/src/lava/magma/compiler/compiler.py +++ b/src/lava/magma/compiler/compiler.py @@ -24,13 +24,17 @@ from lava.magma.compiler.channels.interfaces import ChannelType from lava.magma.compiler.executable import Executable from lava.magma.compiler.node import NodeConfig, Node -from lava.magma.compiler.utils import VarInitializer, PortInitializer +from lava.magma.compiler.utils import VarInitializer, PortInitializer, \ + VarPortInitializer from lava.magma.core import resources from lava.magma.core.model.c.model import AbstractCProcessModel from lava.magma.core.model.model import AbstractProcessModel from lava.magma.core.model.nc.model import AbstractNcProcessModel from lava.magma.core.model.py.model import AbstractPyProcessModel +from lava.magma.core.model.py.ports import RefVarTypeMapping from lava.magma.core.model.sub.model import AbstractSubProcessModel +from lava.magma.core.process.ports.ports import AbstractPort, VarPort, \ + ImplicitVarPort from lava.magma.core.process.process import AbstractProcess from lava.magma.core.resources import CPU, NeuroCore from lava.magma.core.run_configs import RunConfig @@ -49,7 +53,6 @@ def __init__(self, compile_cfg: ty.Optional[ty.Dict[str, ty.Any]] = None): self._compile_config.update(compile_cfg) # ToDo: (AW) Clean this up by avoiding redundant search paths - # ToDo: (AW) @PP Please include RefPorts/VarPorts in connection tracing def _find_processes(self, proc: AbstractProcess, seen_procs: ty.List[AbstractProcess] = None) \ @@ -69,14 +72,14 @@ def _find_processes(self, new_list: ty.List[AbstractProcess] = [] # add processes connecting to the main process - for in_port in proc.in_ports: + for in_port in proc.in_ports.members + proc.var_ports.members: for con in in_port.in_connections: new_list.append(con.process) for con in in_port.out_connections: new_list.append(con.process) # add processes connecting from the main process - for out_port in proc.out_ports: + for out_port in proc.out_ports.members + proc.ref_ports.members: for con in out_port.in_connections: new_list.append(con.process) for con in out_port.out_connections: @@ -251,6 +254,37 @@ def _group_proc_by_model(proc_map: PROC_MAP) \ return grouped_models + # TODO: (PP) This currently only works for PyPorts - needs general solution + # TODO: (PP) Currently does not support 1:many/many:1 connections + @staticmethod + def _map_var_port_class(port: VarPort, + proc_groups: ty.Dict[ty.Type[AbstractProcessModel], + ty.List[AbstractProcess]]): + """Maps the port class of a given VarPort from its source RefPort. This + is needed as implicitly created VarPorts created by connecting RefPorts + directly to Vars, have no LavaType.""" + + # Get the source RefPort of the VarPort + rp = port.get_src_ports() + if len(rp) > 0: + rp = rp[0] + else: + # VarPort is not connect, hence there is no LavaType + return None + + # Get the ProcessModel of the source RefPort + r_pm = None + for pm in proc_groups: + if rp.process in proc_groups[pm]: + r_pm = pm + + # Get the LavaType of the RefPort from its ProcessModel + lt = getattr(r_pm, rp.name) + + # Return mapping of the RefPort class to VarPort class + return RefVarTypeMapping.get(lt.cls) + + # TODO: (PP) possible shorten creation of PortInitializers def _compile_proc_models( self, proc_groups: ty.Dict[ty.Type[AbstractProcessModel], @@ -271,16 +305,43 @@ def _compile_proc_models( # and Ports v = [VarInitializer(v.name, v.shape, v.init, v.id) for v in p.vars] - ports = (list(p.in_ports) + list(p.out_ports) - + list(p.ref_ports)) + ports = (list(p.in_ports) + list(p.out_ports)) ports = [PortInitializer(pt.name, pt.shape, - getattr(pm, pt.name).d_type, + self._get_port_dtype(pt, pm), pt.__class__.__name__, pp_ch_size) for pt in ports] + # Create RefPort (also use PortInitializers) + ref_ports = list(p.ref_ports) + ref_ports = [ + PortInitializer(pt.name, + pt.shape, + self._get_port_dtype(pt, pm), + pt.__class__.__name__, + pp_ch_size) for pt in ref_ports] + # Create VarPortInitializers (contain also the Var name) + var_ports = [] + for pt in list(p.var_ports): + var_ports.append( + VarPortInitializer( + pt.name, + pt.shape, + pt.var.name, + self._get_port_dtype(pt, pm), + pt.__class__.__name__, + pp_ch_size, + self._map_var_port_class(pt, proc_groups))) + + # Set implicit VarPorts (created by connecting a RefPort + # directly to a Var) as attribute to ProcessModel + if isinstance(pt, ImplicitVarPort): + setattr(pm, pt.name, pt) + # Assigns initializers to builder b.set_variables(v) b.set_py_ports(ports) + b.set_ref_ports(ref_ports) + b.set_var_ports(var_ports) b.check_all_vars_and_ports_set() py_builders[p] = b elif issubclass(pm, AbstractCProcessModel): @@ -496,6 +557,27 @@ def _get_channel_type(src: ty.Type[AbstractProcessModel], f"'({src.__name__}, {dst.__name__})' yet." ) + @staticmethod + def _get_port_dtype(port: AbstractPort, + proc_model: ty.Type[AbstractProcessModel]) -> type: + """Returns the d_type of a Process Port, as specified in the + corresponding PortImplementation of the ProcessModel implementing the + Process""" + + # In-, Out-, Ref- and explicit VarPorts + if hasattr(proc_model, port.name): + # Handle VarPorts (use dtype of corresponding Var) + if isinstance(port, VarPort): + return getattr(proc_model, port.var.name).d_type + return getattr(proc_model, port.name).d_type + # Implicitly created VarPorts + elif isinstance(port, ImplicitVarPort): + return getattr(proc_model, port.var.name).d_type + # Port has different name in Process and ProcessModel + else: + raise AssertionError("Port {!r} not found in " + "ProcessModel {!r}".format(port, proc_model)) + # ToDo: (AW) Fix hard-coded hacks in this method and extend to other # channel types def _create_channel_builders(self, proc_map: PROC_MAP) \ @@ -525,7 +607,7 @@ def _create_channel_builders(self, proc_map: PROC_MAP) \ # Find destination ports for each source port for src_pt in src_ports: # Create PortInitializer for source port - src_pt_dtype = getattr(src_pm, src_pt.name).d_type + src_pt_dtype = self._get_port_dtype(src_pt, src_pm) src_pt_init = PortInitializer( src_pt.name, src_pt.shape, src_pt_dtype, src_pt.__class__.__name__, ch_size) @@ -540,7 +622,7 @@ def _create_channel_builders(self, proc_map: PROC_MAP) \ # Find appropriate channel type ch_type = self._get_channel_type(src_pm, dst_pm) # Create PortInitializer for destination port - dst_pt_d_type = getattr(dst_pm, dst_pt.name).d_type + dst_pt_d_type = self._get_port_dtype(dst_pt, dst_pm) dst_pt_init = PortInitializer( dst_pt.name, dst_pt.shape, dst_pt_d_type, dst_pt.__class__.__name__, ch_size) @@ -548,6 +630,13 @@ def _create_channel_builders(self, proc_map: PROC_MAP) \ chb = ChannelBuilderMp( ch_type, src_p, dst_p, src_pt_init, dst_pt_init) channel_builders.append(chb) + # Create additional channel builder for every VarPort + if isinstance(dst_pt, VarPort): + # RefPort to VarPort connections need channels for + # read and write + rv_chb = ChannelBuilderMp( + ch_type, dst_p, src_p, dst_pt_init, src_pt_init) + channel_builders.append(rv_chb) return channel_builders diff --git a/src/lava/magma/compiler/utils.py b/src/lava/magma/compiler/utils.py index fbf76d07b..14da90baf 100644 --- a/src/lava/magma/compiler/utils.py +++ b/src/lava/magma/compiler/utils.py @@ -1,8 +1,6 @@ import typing as ty from dataclasses import dataclass -import numpy as np - @dataclass class VarInitializer: @@ -16,6 +14,18 @@ class VarInitializer: class PortInitializer: name: str shape: ty.Tuple[int, ...] - d_type: ty.Type[np.intc] + d_type: type + port_type: str + size: int + + +# check if can be a subclass of PortInitializer +@dataclass +class VarPortInitializer: + name: str + shape: ty.Tuple[int, ...] + var_name: str + d_type: type port_type: str size: int + port_cls: type diff --git a/src/lava/magma/core/model/interfaces.py b/src/lava/magma/core/model/interfaces.py index d8e3b918d..4170f89a6 100644 --- a/src/lava/magma/core/model/interfaces.py +++ b/src/lava/magma/core/model/interfaces.py @@ -2,33 +2,33 @@ # SPDX-License-Identifier: BSD-3-Clause # See: https://spdx.org/licenses/ import typing as ty -from abc import ABC - +from abc import ABC, abstractmethod from lava.magma.compiler.channels.interfaces import AbstractCspPort -# ToDo: (AW) This type hierarchy is still not clean. csp_port could be a -# CspSendPort or CspRecvPort so down-stream classes can't do proper type -# inference to determine if there is a send/peek/recv/probe method. class AbstractPortImplementation(ABC): def __init__( self, process_model: "AbstractProcessModel", # noqa: F821 - csp_ports: ty.List[AbstractCspPort] = [], shape: ty.Tuple[int, ...] = tuple(), d_type: type = int, ): self._process_model = process_model - self._csp_ports = ( - csp_ports if isinstance(csp_ports, list) else [csp_ports] - ) self._shape = shape self._d_type = d_type + @property + @abstractmethod + def csp_ports(self) -> ty.List[AbstractCspPort]: + """Returns all csp ports of the port.""" + pass + def start(self): - for csp_port in self._csp_ports: + """Start all csp ports.""" + for csp_port in self.csp_ports: csp_port.start() def join(self): - for csp_port in self._csp_ports: + """Join all csp ports""" + for csp_port in self.csp_ports: csp_port.join() diff --git a/src/lava/magma/core/model/py/model.py b/src/lava/magma/core/model/py/model.py index 4652eace5..8f2e24ec3 100644 --- a/src/lava/magma/core/model/py/model.py +++ b/src/lava/magma/core/model/py/model.py @@ -8,7 +8,7 @@ from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort from lava.magma.core.model.model import AbstractProcessModel -from lava.magma.core.model.py.ports import AbstractPyPort +from lava.magma.core.model.py.ports import AbstractPyPort, PyVarPort from lava.magma.runtime.mgmt_token_enums import ( enum_to_np, MGMT_COMMAND, @@ -37,12 +37,16 @@ def __init__(self): self.process_to_service_data: ty.Optional[CspSendPort] = None self.service_to_process_data: ty.Optional[CspRecvPort] = None self.py_ports: ty.List[AbstractPyPort] = [] + self.var_ports: ty.List[PyVarPort] = [] self.var_id_to_var_map: ty.Dict[int, ty.Any] = {} def __setattr__(self, key: str, value: ty.Any): self.__dict__[key] = value if isinstance(value, AbstractPyPort): self.py_ports.append(value) + # Store all VarPorts for efficient RefPort -> VarPort handling + if isinstance(value, PyVarPort): + self.var_ports.append(value) def start(self): self.service_to_process_cmd.start() @@ -92,9 +96,6 @@ def run_lrn(self): def run_post_mgmt(self): pass - def run_host_mgmt(self): - pass - def pre_guard(self): pass @@ -104,64 +105,96 @@ def lrn_guard(self): def post_guard(self): pass - def host_guard(self): - pass - + # TODO: (PP) need to handle PAUSE command def run(self): + """Retrieves commands from the runtime service to iterate through the + phases of Loihi and calls their corresponding methods of the + ProcessModels. The phase is retrieved from runtime service + (service_to_process_cmd). After calling the method of a phase of all + ProcessModels the runtime service is informed about completion. The + loop ends when the STOP command is received.""" while True: + # Probe if there is a new command from the runtime service if self.service_to_process_cmd.probe(): phase = self.service_to_process_cmd.recv() if np.array_equal(phase, MGMT_COMMAND.STOP): self.process_to_service_ack.send(MGMT_RESPONSE.TERMINATED) self.join() return + # Spiking phase - increase time step if np.array_equal(phase, PyLoihiProcessModel.Phase.SPK): self.current_ts += 1 self.run_spk() + self.process_to_service_ack.send(MGMT_RESPONSE.DONE) + # Pre-management phase elif np.array_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT): + # Enable via guard method if self.pre_guard(): self.run_pre_mgmt() - self._handle_get_set_var() + self.process_to_service_ack.send(MGMT_RESPONSE.DONE) + # Handle VarPort requests from RefPorts + if len(self.var_ports) > 0: + self._handle_var_ports() + # Learning phase elif np.array_equal(phase, PyLoihiProcessModel.Phase.LRN): + # Enable via guard method if self.lrn_guard(): self.run_lrn() + self.process_to_service_ack.send(MGMT_RESPONSE.DONE) + # Post-management phase elif np.array_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): + # Enable via guard method if self.post_guard(): self.run_post_mgmt() - self._handle_get_set_var() + self.process_to_service_ack.send(MGMT_RESPONSE.DONE) + # Handle VarPort requests from RefPorts + if len(self.var_ports) > 0: + self._handle_var_ports() + # Host phase - called at the last time step before STOP elif np.array_equal(phase, PyLoihiProcessModel.Phase.HOST): - if self.host_guard(): - self.run_host_mgmt() + # Handle get/set Var requests from runtime service + self._handle_get_set_var() else: raise ValueError(f"Wrong Phase Info Received : {phase}") - self.process_to_service_ack.send(MGMT_RESPONSE.DONE) - else: - self._handle_get_set_var() + # FIXME: (PP) might not be able to perform get/set during pause def _handle_get_set_var(self): - while self.service_to_process_req.probe(): - req_port: CspRecvPort = self.service_to_process_req - request: np.ndarray = req_port.recv() - if np.array_equal(request, REQ_TYPE.GET): - self._handle_get_var() - elif np.array_equal(request, REQ_TYPE.SET): - self._handle_set_var() - else: - raise RuntimeError(f"Unknown request type {request}") + """Handles all get/set Var requests from the runtime service and calls + the corresponding handling methods. The loop ends upon a + new command from runtime service after all get/set Var requests have + been handled.""" + while True: + # Probe if there is a get/set Var request from runtime service + if self.service_to_process_req.probe(): + # Get the type of the request + request = self.service_to_process_req.recv() + if np.array_equal(request, REQ_TYPE.GET): + self._handle_get_var() + elif np.array_equal(request, REQ_TYPE.SET): + self._handle_set_var() + else: + raise RuntimeError(f"Unknown request type {request}") + + # End if another command from runtime service arrives + if self.service_to_process_cmd.probe(): + return def _handle_get_var(self): - # 1. Recv Var ID - req_port: CspRecvPort = self.service_to_process_req - var_id: int = req_port.recv()[0].item() - var_name: str = self.var_id_to_var_map[var_id] - var: ty.Any = getattr(self, var_name) + """Handles the get Var command from runtime service.""" + # 1. Receive Var ID and retrieve the Var + var_id = self.service_to_process_req.recv()[0].item() + var_name = self.var_id_to_var_map[var_id] + var = getattr(self, var_name) # 2. Send Var data - data_port: CspSendPort = self.process_to_service_data + data_port = self.process_to_service_data + # Header corresponds to number of values + # Data is either send once (for int) or one by one (array) if isinstance(var, int) or isinstance(var, np.integer): data_port.send(enum_to_np(1)) data_port.send(enum_to_np(var)) elif isinstance(var, np.ndarray): + # FIXME: send a whole vector (also runtime_service.py) var_iter = np.nditer(var) num_items: np.integer = np.prod(var.shape) data_port.send(enum_to_np(num_items)) @@ -169,25 +202,28 @@ def _handle_get_var(self): data_port.send(enum_to_np(value)) def _handle_set_var(self): - # 1. Recv Var ID - req_port: CspRecvPort = self.service_to_process_req - var_id: int = req_port.recv()[0].item() - var_name: str = self.var_id_to_var_map[var_id] - var: ty.Any = getattr(self, var_name) - - # 2. Recv Var data - data_port: CspRecvPort = self.service_to_process_data + """Handles the set Var command from runtime service.""" + # 1. Receive Var ID and retrieve the Var + var_id = self.service_to_process_req.recv()[0].item() + var_name = self.var_id_to_var_map[var_id] + var = getattr(self, var_name) + + # 2. Receive Var data + data_port = self.service_to_process_data if isinstance(var, int) or isinstance(var, np.integer): - data_port.recv() # Ignore as this will be 1 (num_items) + # First item is number of items (1) - not needed + data_port.recv() + # Data to set buffer = data_port.recv()[0] if isinstance(var, int): setattr(self, var_name, buffer.item()) else: setattr(self, var_name, buffer.astype(var.dtype)) elif isinstance(var, np.ndarray): + # First item is number of items num_items = data_port.recv()[0] var_iter = np.nditer(var, op_flags=['readwrite']) - + # Set data one by one for i in var_iter: if num_items == 0: break @@ -195,3 +231,17 @@ def _handle_set_var(self): i[...] = data_port.recv()[0] else: raise RuntimeError("Unsupported type") + + # TODO: (PP) use select(..) to service VarPorts instead of a loop + def _handle_var_ports(self): + """Handles read/write requests on any VarPorts. The loop ends upon a + new command from runtime service after all VarPort service requests have + been handled.""" + while True: + # Loop through read/write requests of each VarPort + for vp in self.var_ports: + vp.service() + + # End if another command from runtime service arrives + if self.service_to_process_cmd.probe(): + return diff --git a/src/lava/magma/core/model/py/ports.py b/src/lava/magma/core/model/py/ports.py index 4169e0966..0b9377e24 100644 --- a/src/lava/magma/core/model/py/ports.py +++ b/src/lava/magma/core/model/py/ports.py @@ -3,16 +3,21 @@ # See: https://spdx.org/licenses/ import typing as ty from abc import abstractmethod -from enum import Enum import functools as ft - import numpy as np +from lava.magma.compiler.channels.interfaces import AbstractCspPort +from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort from lava.magma.core.model.interfaces import AbstractPortImplementation +from lava.magma.runtime.mgmt_token_enums import enum_to_np class AbstractPyPort(AbstractPortImplementation): - pass + @property + @abstractmethod + def csp_ports(self) -> ty.List[AbstractCspPort]: + """Returns all csp ports of the port.""" + pass class PyInPort(AbstractPyPort): @@ -25,6 +30,15 @@ class PyInPort(AbstractPyPort): SCALAR_DENSE: ty.Type["PyInPortScalarDense"] = None SCALAR_SPARSE: ty.Type["PyInPortScalarSparse"] = None + def __init__(self, csp_recv_ports: ty.List[CspRecvPort], *args): + self._csp_recv_ports = csp_recv_ports + super().__init__(*args) + + @property + def csp_ports(self) -> ty.List[AbstractCspPort]: + """Returns all csp ports of the port.""" + return self._csp_recv_ports + @abstractmethod def recv(self): pass @@ -41,14 +55,14 @@ class PyInPortVectorDense(PyInPort): def recv(self) -> np.ndarray: return ft.reduce( lambda acc, csp_port: acc + csp_port.recv(), - self._csp_ports, + self._csp_recv_ports, np.zeros(self._shape, self._d_type), ) def peek(self) -> np.ndarray: return ft.reduce( lambda acc, csp_port: acc + csp_port.peek(), - self._csp_ports, + self._csp_recv_ports, np.zeros(self._shape, self._d_type), ) @@ -83,14 +97,6 @@ def peek(self) -> ty.Tuple[int, int]: PyInPort.SCALAR_SPARSE = PyInPortScalarSparse -# ToDo: Remove... not needed anymore -class _PyInPort(Enum): - VEC_DENSE = PyInPortVectorDense - VEC_SPARSE = PyInPortVectorSparse - SCALAR_DENSE = PyInPortScalarDense - SCALAR_SPARSE = PyInPortScalarSparse - - class PyOutPort(AbstractPyPort): """Python implementation of OutPort used within AbstractPyProcessModels.""" @@ -99,6 +105,15 @@ class PyOutPort(AbstractPyPort): SCALAR_DENSE: ty.Type["PyOutPortScalarDense"] = None SCALAR_SPARSE: ty.Type["PyOutPortScalarSparse"] = None + def __init__(self, csp_send_ports: ty.List[CspSendPort], *args): + self._csp_send_ports = csp_send_ports + super().__init__(*args) + + @property + def csp_ports(self) -> ty.List[AbstractCspPort]: + """Returns all csp ports of the port.""" + return self._csp_send_ports + @abstractmethod def send(self, data: ty.Union[np.ndarray, int]): pass @@ -110,7 +125,7 @@ def flush(self): class PyOutPortVectorDense(PyOutPort): def send(self, data: np.ndarray): """Sends data only if port is not dangling.""" - for csp_port in self._csp_ports: + for csp_port in self._csp_send_ports: csp_port.send(data) @@ -135,12 +150,9 @@ def send(self, data: int, idx: int): PyOutPort.SCALAR_SPARSE = PyOutPortScalarSparse -# ToDo: Remove... not needed anymore -class _PyOutPort(Enum): - VEC_DENSE = PyOutPortVectorDense - VEC_SPARSE = PyOutPortVectorSparse - SCALAR_DENSE = PyOutPortScalarDense - SCALAR_SPARSE = PyOutPortScalarSparse +class VarPortCmd: + GET = enum_to_np(0) + SET = enum_to_np(1) class PyRefPort(AbstractPyPort): @@ -151,6 +163,22 @@ class PyRefPort(AbstractPyPort): SCALAR_DENSE: ty.Type["PyRefPortScalarDense"] = None SCALAR_SPARSE: ty.Type["PyRefPortScalarSparse"] = None + def __init__(self, + csp_send_port: ty.Optional[CspSendPort], + csp_recv_port: ty.Optional[CspRecvPort], *args): + self._csp_recv_port = csp_recv_port + self._csp_send_port = csp_send_port + super().__init__(*args) + + @property + def csp_ports(self) -> ty.List[AbstractCspPort]: + """Returns all csp ports of the port.""" + if self._csp_send_port is not None and self._csp_recv_port is not None: + return [self._csp_send_port, self._csp_recv_port] + else: + # In this case the port was not connected + return [] + def read( self, ) -> ty.Union[ @@ -172,10 +200,21 @@ def write( class PyRefPortVectorDense(PyRefPort): def read(self) -> np.ndarray: - pass + """Requests the data from a VarPort and returns the data.""" + if self._csp_send_port and self._csp_recv_port: + header = np.ones(self._csp_send_port.shape) * VarPortCmd.GET + self._csp_send_port.send(header) + + return self._csp_recv_port.recv() + + return np.zeros(self._shape, self._d_type) def write(self, data: np.ndarray): - pass + """Sends the data to a VarPort to set its Var.""" + if self._csp_send_port: + header = np.ones(self._csp_send_port.shape) * VarPortCmd.SET + self._csp_send_port.send(header) + self._csp_send_port.send(data) class PyRefPortVectorSparse(PyRefPort): @@ -208,9 +247,98 @@ def write(self, data: int, idx: int): PyRefPort.SCALAR_SPARSE = PyRefPortScalarSparse -# ToDo: Remove... not needed anymore -class _PyRefPort(Enum): - VEC_DENSE = PyRefPortVectorDense - VEC_SPARSE = PyRefPortVectorSparse - SCALAR_DENSE = PyRefPortScalarDense - SCALAR_SPARSE = PyRefPortScalarSparse +class PyVarPort(AbstractPyPort): + """Python implementation of VarPort used within AbstractPyProcessModel. + """ + + VEC_DENSE: ty.Type["PyVarPortVectorDense"] = None + VEC_SPARSE: ty.Type["PyVarPortVectorSparse"] = None + SCALAR_DENSE: ty.Type["PyVarPortScalarDense"] = None + SCALAR_SPARSE: ty.Type["PyVarPortScalarSparse"] = None + + def __init__(self, + var_name: str, + csp_send_port: ty.Optional[CspSendPort], + csp_recv_port: ty.Optional[CspRecvPort], *args): + self._csp_recv_port = csp_recv_port + self._csp_send_port = csp_send_port + self.var_name = var_name + super().__init__(*args) + + @property + def csp_ports(self) -> ty.List[AbstractCspPort]: + """Returns all csp ports of the port.""" + if self._csp_send_port is not None and self._csp_recv_port is not None: + return [self._csp_send_port, self._csp_recv_port] + else: + # In this case the port was not connected + return [] + + def service(self): + pass + + +class PyVarPortVectorDense(PyVarPort): + def service(self): + """Sets the received value to the given var or sends the value of the + var to the csp_send_port, depending on the received header information + of the csp_recv_port.""" + + # Inspect incoming data + if self._csp_send_port is not None and self._csp_recv_port is not None: + if self._csp_recv_port.probe(): + cmd = enum_to_np(self._csp_recv_port.recv()[0]) + + # Set the value of the Var with the given data + if np.array_equal(cmd, VarPortCmd.SET): + data = self._csp_recv_port.recv() + setattr(self._process_model, self.var_name, data) + elif np.array_equal(cmd, VarPortCmd.GET): + data = getattr(self._process_model, self.var_name) + self._csp_send_port.send(data) + else: + raise ValueError(f"Wrong Command Info Received : {cmd}") + + +class PyVarPortVectorSparse(PyVarPort): + def recv(self) -> ty.Tuple[np.ndarray, np.ndarray]: + pass + + def peek(self) -> ty.Tuple[np.ndarray, np.ndarray]: + pass + + +class PyVarPortScalarDense(PyVarPort): + def recv(self) -> int: + pass + + def peek(self) -> int: + pass + + +class PyVarPortScalarSparse(PyVarPort): + def recv(self) -> ty.Tuple[int, int]: + pass + + def peek(self) -> ty.Tuple[int, int]: + pass + + +PyVarPort.VEC_DENSE = PyVarPortVectorDense +PyVarPort.VEC_SPARSE = PyVarPortVectorSparse +PyVarPort.SCALAR_DENSE = PyVarPortScalarDense +PyVarPort.SCALAR_SPARSE = PyVarPortScalarSparse + + +class RefVarTypeMapping: + """Class to get the mapping of PyRefPort types to PyVarPortTypes.""" + + mapping: ty.Dict[PyRefPort, PyVarPort] = { + PyRefPortVectorDense: PyVarPortVectorDense, + PyRefPortVectorSparse: PyVarPortVectorSparse, + PyRefPortScalarDense: PyVarPortScalarDense, + PyRefPortScalarSparse: PyVarPortScalarSparse} + + @classmethod + def get(cls, ref_port: PyRefPort): + return cls.mapping[ref_port] diff --git a/src/lava/magma/core/process/ports/ports.py b/src/lava/magma/core/process/ports/ports.py index 32f609e36..803347d89 100644 --- a/src/lava/magma/core/process/ports/ports.py +++ b/src/lava/magma/core/process/ports/ports.py @@ -347,6 +347,14 @@ def connect( ----------- :param ports: The AbstractRVPort(s) to connect to. """ + for p in to_list(ports): + if not isinstance(p, RefPort) and not isinstance(p, VarPort): + raise TypeError( + "RefPorts can only be connected to RefPorts or " + "VarPorts: {!r}: {!r} -> {!r}: {!r} To connect a RefPort " + "to a Var, use ".format( + self.process.__class__.__name__, self.name, + p.process.__class__.__name__, p.name)) self._connect_forward(to_list(ports), AbstractRVPort) def connect_from(self, ports: ty.Union["RefPort", ty.List["RefPort"]]): @@ -357,6 +365,13 @@ def connect_from(self, ports: ty.Union["RefPort", ty.List["RefPort"]]): ---------- :param ports: The RefPort(s) that connect to this RefPort. """ + for p in to_list(ports): + if not isinstance(p, RefPort): + raise TypeError( + "RefPorts can only receive connections from RefPorts: " + "{!r}: {!r} -> {!r}: {!r}".format( + self.process.__class__.__name__, self.name, + p.process.__class__.__name__, p.name)) self._connect_backward(to_list(ports), RefPort) def connect_var(self, variables: ty.Union[Var, ty.List[Var]]): @@ -389,12 +404,19 @@ def connect_var(self, variables: ty.Union[Var, ty.List[Var]]): if var_shape != v.shape: raise AssertionError("All 'vars' must have same shape.") # Create a VarPort to wrap Var - vp = VarPort(v) + vp = ImplicitVarPort(v) # Propagate name and parent process of Var to VarPort - vp.name = v.name + "_port" + vp.name = "_" + v.name + "_implicit_port" if v.process is not None: # Only assign when parent process is already assigned vp.process = v.process + # VarPort name could shadow existing attribute + if hasattr(v.process, vp.name): + raise AssertionError( + "Name of implicit VarPort might conflict" + " with existing attribute.") + setattr(v.process, vp.name, vp) + v.process.var_ports.add_members({vp.name: vp}) var_ports.append(vp) # Connect RefPort to VarPorts that wrap Vars self.connect(var_ports) @@ -440,6 +462,13 @@ def connect(self, ports: ty.Union["VarPort", ty.List["VarPort"]]): ---------- :param ports: The VarPort(s) to connect to. """ + for p in to_list(ports): + if not isinstance(p, VarPort): + raise TypeError( + "VarPorts can only be connected to VarPorts: " + "{!r}: {!r} -> {!r}: {!r}".format( + self.process.__class__.__name__, self.name, + p.process.__class__.__name__, p.name)) self._connect_forward(to_list(ports), VarPort) def connect_from( @@ -452,9 +481,22 @@ def connect_from( ---------- :param ports: The AbstractRVPort(s) that connect to this VarPort. """ + for p in to_list(ports): + if not isinstance(p, RefPort) and not isinstance(p, VarPort): + raise TypeError( + "VarPorts can only receive connections from RefPorts or " + "VarPorts: {!r}: {!r} -> {!r}: {!r}".format( + self.process.__class__.__name__, self.name, + p.process.__class__.__name__, p.name)) self._connect_backward(to_list(ports), AbstractRVPort) +class ImplicitVarPort(VarPort): + """Sub class for VarPort to identify implicitly created VarPorts when + a RefPort connects directly to a Var.""" + pass + + class AbstractVirtualPort(ABC): """Abstract base class interface for any type of port that merely serves to transforms the properties of a user-defined port. @@ -543,7 +585,7 @@ def _get_new_shape(ports: ty.List[AbstractPort], axis): # Compute total size along concatenation axis total_size += shape[axis] # Extract shape dimensions other than concatenation axis - shapes_ex_axis.append(shape[:axis] + shape[axis + 1 :]) + shapes_ex_axis.append(shape[:axis] + shape[axis + 1:]) if len(shapes_ex_axis) > 1: shapes_incompatible = shapes_ex_axis[-2] != shapes_ex_axis[-1] diff --git a/src/lava/magma/runtime/runtime_service.py b/src/lava/magma/runtime/runtime_service.py index 871b2b397..819825886 100644 --- a/src/lava/magma/runtime/runtime_service.py +++ b/src/lava/magma/runtime/runtime_service.py @@ -91,155 +91,194 @@ class Phase: POST_MGMT = enum_to_np(4) HOST = enum_to_np(5) - def _next_phase(self, curr_phase): + def _next_phase(self, curr_phase, is_last_time_step: bool): + """Advances the current phase to the next phase. + On the first time step it starts with HOST phase and advances to SPK. + Afterwards it loops: SPK -> PRE_MGMT -> LRN -> POST_MGMT -> SPK + On the last time step POST_MGMT advances to HOST phase.""" if curr_phase == LoihiPyRuntimeService.Phase.SPK: return LoihiPyRuntimeService.Phase.PRE_MGMT elif curr_phase == LoihiPyRuntimeService.Phase.PRE_MGMT: return LoihiPyRuntimeService.Phase.LRN elif curr_phase == LoihiPyRuntimeService.Phase.LRN: return LoihiPyRuntimeService.Phase.POST_MGMT - elif curr_phase == LoihiPyRuntimeService.Phase.POST_MGMT: + elif curr_phase == LoihiPyRuntimeService.Phase.POST_MGMT and \ + is_last_time_step: return LoihiPyRuntimeService.Phase.HOST + elif curr_phase == LoihiPyRuntimeService.Phase.POST_MGMT and not \ + is_last_time_step: + return LoihiPyRuntimeService.Phase.SPK elif curr_phase == LoihiPyRuntimeService.Phase.HOST: return LoihiPyRuntimeService.Phase.SPK def _send_pm_cmd(self, phase: MGMT_COMMAND): + """Sends a command (phase information) to all ProcessModels.""" for send_port in self.service_to_process_cmd: send_port.send(phase) def _send_pm_req_given_model_id(self, model_id: int, *requests): - process_idx: int = self.model_ids.index(model_id) - req_port: CspSendPort = self.service_to_process_req[process_idx] + """Sends requests to a ProcessModel given by the model id.""" + process_idx = self.model_ids.index(model_id) + req_port = self.service_to_process_req[process_idx] for request in requests: req_port.send(request) - def _get_pm_resp(self, phase) -> ty.Iterable[MGMT_RESPONSE]: + def _get_pm_resp(self) -> ty.Iterable[MGMT_RESPONSE]: + """Retrieves responses of all ProcessModels.""" rcv_msgs = [] - num_responses_expected: int = len(self.model_ids) - counter: int = 0 + num_responses_expected = len(self.model_ids) + counter = 0 while counter < num_responses_expected: ptos_recv_port = self.process_to_service_ack[counter] - self._handle_get_set(phase) if ptos_recv_port.probe(): rcv_msgs.append(ptos_recv_port.recv()) counter += 1 return rcv_msgs def _relay_to_runtime_data_given_model_id(self, model_id: int): - """Relays data received from pm to runtime""" - process_idx: int = self.model_ids.index(model_id) + """Relays data received from ProcessModel given by model id to the + runtime""" + process_idx = self.model_ids.index(model_id) - data_recv_port: CspRecvPort = self.process_to_service_data[process_idx] - data_relay_port: CspSendPort = self.service_to_runtime_data - num_items: np.ndarray = data_recv_port.recv() + data_recv_port = self.process_to_service_data[process_idx] + data_relay_port = self.service_to_runtime_data + num_items = data_recv_port.recv() data_relay_port.send(num_items) for i in range(num_items[0]): data_relay_port.send(data_recv_port.recv()) def _relay_to_pm_data_given_model_id(self, model_id: int): - """Relays data received from runtime to pm""" - process_idx: int = self.model_ids.index(model_id) - - data_recv_port: CspRecvPort = self.runtime_to_service_data - data_relay_port: CspSendPort = self.service_to_process_data[process_idx] - # recv and relay num_items - num_items: np.ndarray = data_recv_port.recv() + """Relays data received from the runtime to the ProcessModel given by + the model id.""" + process_idx = self.model_ids.index(model_id) + + data_recv_port = self.runtime_to_service_data + data_relay_port = self.service_to_process_data[process_idx] + # Receive and relay number of items + num_items = data_recv_port.recv() data_relay_port.send(num_items) - # recv and relay data1, data2, ... + # Receive and relay data1, data2, ... for i in range(num_items[0].item()): data_relay_port.send(data_recv_port.recv()) def _relay_pm_ack_given_model_id(self, model_id: int): - """Relays ack received from pm to runtime""" - process_idx: int = self.model_ids.index(model_id) + """Relays ack received from ProcessModel given by model id to the + runtime.""" + process_idx = self.model_ids.index(model_id) - ack_recv_port: CspRecvPort = self.process_to_service_ack[process_idx] - ack_relay_port: CspSendPort = self.service_to_runtime_ack + ack_recv_port = self.process_to_service_ack[process_idx] + ack_relay_port = self.service_to_runtime_ack ack_relay_port.send(ack_recv_port.recv()) def run(self): - phase = LoihiPyRuntimeService.Phase.SPK + """Retrieves commands from the runtime. On STOP or PAUSE commands all + ProcessModels are notified and expected to TERMINATE or PAUSE, + respectively. Otherwise the number of time steps is received as command. + In this case iterate through the phases of the Loihi protocol until the + last time step is reached. The runtime is informed after the last time + step. The loop ends when receiving the STOP command from the runtime.""" + phase = LoihiPyRuntimeService.Phase.HOST while True: + # Probe if there is a new command from the runtime if self.runtime_to_service_cmd.probe(): command = self.runtime_to_service_cmd.recv() if np.array_equal(command, MGMT_COMMAND.STOP): + # Inform all ProcessModels about the STOP command self._send_pm_cmd(command) - rsps = self._get_pm_resp(phase) + rsps = self._get_pm_resp() for rsp in rsps: if not np.array_equal(rsp, MGMT_RESPONSE.TERMINATED): raise ValueError(f"Wrong Response Received : {rsp}") + # Inform the runtime about successful termination self.service_to_runtime_ack.send(MGMT_RESPONSE.TERMINATED) self.join() return elif np.array_equal(command, MGMT_COMMAND.PAUSE): + # Inform all ProcessModels about the PAUSE command self._send_pm_cmd(command) - rsps = self._get_pm_resp(phase) + rsps = self._get_pm_resp() for rsp in rsps: if not np.array_equal(rsp, MGMT_RESPONSE.PAUSED): raise ValueError(f"Wrong Response Received : {rsp}") + # Inform the runtime about successful pausing self.service_to_runtime_ack.send(MGMT_RESPONSE.PAUSED) break else: + # The number of time steps was received ("command") + # Start iterating through Loihi phases curr_time_step = 0 - phase = LoihiPyRuntimeService.Phase.SPK - while not np.array_equal(enum_to_np(curr_time_step), - command): + phase = LoihiPyRuntimeService.Phase.HOST + while True: + # Check if it is the last time step + is_last_ts = np.array_equal(enum_to_np(curr_time_step), + command) + # Advance to the next phase + phase = self._next_phase(phase, is_last_ts) + # Increase time step if spiking phase if np.array_equal(phase, LoihiPyRuntimeService.Phase.SPK): curr_time_step += 1 + # Inform ProcessModels about current phase self._send_pm_cmd(phase) - rsps = self._get_pm_resp(phase) - for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.DONE): - raise ValueError( - f"Wrong Response Received : {rsp}") - is_last_ts = np.array_equal(enum_to_np(curr_time_step), - command) - is_last_phase = np.array_equal(phase, - LoihiPyRuntimeService. - Phase.POST_MGMT) - if not (is_last_ts and is_last_phase): - phase = self._next_phase(phase) + # ProcessModels respond with DONE if not HOST phase + if not np.array_equal( + phase, LoihiPyRuntimeService.Phase.HOST): + rsps = self._get_pm_resp() + for rsp in rsps: + if not np.array_equal(rsp, MGMT_RESPONSE.DONE): + raise ValueError( + f"Wrong Response Received : {rsp}") + + # If HOST phase (last time step ended) break the loop + if np.array_equal( + phase, LoihiPyRuntimeService.Phase.HOST): + break + + # Inform the runtime that last time step was reached self.service_to_runtime_ack.send(MGMT_RESPONSE.DONE) + # Handle get/set Var self._handle_get_set(phase) def _handle_get_set(self, phase): - if np.array_equal(phase, LoihiPyRuntimeService.Phase.PRE_MGMT) or \ - np.array_equal(phase, LoihiPyRuntimeService.Phase.POST_MGMT): - while self.runtime_to_service_req.probe(): - request = self.runtime_to_service_req.recv() - if np.array_equal(request, REQ_TYPE.GET): - requests: ty.List[np.ndarray] = [request] - # recv model_id - model_id: int = \ - self.runtime_to_service_req.recv()[ - 0].item() - # recv var_id - requests.append( - self.runtime_to_service_req.recv()) - self._send_pm_req_given_model_id(model_id, - *requests) - - self._relay_to_runtime_data_given_model_id( - model_id) - elif np.array_equal(request, REQ_TYPE.SET): - requests: ty.List[np.ndarray] = [request] - # recv model_id - model_id: int = \ - self.runtime_to_service_req.recv()[ - 0].item() - # recv var_id - requests.append( - self.runtime_to_service_req.recv()) - self._send_pm_req_given_model_id(model_id, - *requests) - - self._relay_to_pm_data_given_model_id( - model_id) - else: - raise RuntimeError( - f"Unknown request {request}") + if np.array_equal(phase, LoihiPyRuntimeService.Phase.HOST): + while True: + if self.runtime_to_service_req.probe(): + request = self.runtime_to_service_req.recv() + if np.array_equal(request, REQ_TYPE.GET): + requests: ty.List[np.ndarray] = [request] + # recv model_id + model_id: int = \ + self.runtime_to_service_req.recv()[ + 0].item() + # recv var_id + requests.append( + self.runtime_to_service_req.recv()) + self._send_pm_req_given_model_id(model_id, + *requests) + + self._relay_to_runtime_data_given_model_id( + model_id) + elif np.array_equal(request, REQ_TYPE.SET): + requests: ty.List[np.ndarray] = [request] + # recv model_id + model_id: int = \ + self.runtime_to_service_req.recv()[ + 0].item() + # recv var_id + requests.append( + self.runtime_to_service_req.recv()) + self._send_pm_req_given_model_id(model_id, + *requests) + + self._relay_to_pm_data_given_model_id( + model_id) + else: + raise RuntimeError( + f"Unknown request {request}") + + if self.runtime_to_service_cmd.probe(): + return class LoihiCRuntimeService(AbstractRuntimeService): diff --git a/tests/lava/magma/compiler/test_compiler.py b/tests/lava/magma/compiler/test_compiler.py index 385a2dfc0..7bcfe3476 100644 --- a/tests/lava/magma/compiler/test_compiler.py +++ b/tests/lava/magma/compiler/test_compiler.py @@ -14,24 +14,27 @@ from lava.magma.core.sync.domain import SyncDomain from lava.magma.core.sync.protocol import AbstractSyncProtocol from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol -from lava.magma.core.process.ports.ports import InPort, OutPort -from lava.magma.core.model.py.ports import PyInPort, PyOutPort +from lava.magma.core.process.ports.ports import ( + InPort, OutPort, RefPort, VarPort) +from lava.magma.core.model.py.ports import PyInPort, PyOutPort, PyRefPort, \ + PyVarPort from lava.magma.core.run_configs import RunConfig from lava.magma.core.model.py.type import LavaPyType from lava.magma.core.process.variable import Var, VarServer from lava.magma.core.resources import CPU -# minimal process with an InPort and OutPortA +# A minimal process (A) with an InPort, OutPort and RefPort class ProcA(AbstractProcess): def __init__(self, **kwargs): super().__init__(**kwargs) # Use ReduceOp to allow for multiple input connections self.inp = InPort(shape=(1,), reduce_op=ReduceSum) self.out = OutPort(shape=(1,)) + self.ref = RefPort(shape=(10,)) -# Another minimal process (does not matter that it's identical to ProcA) +# Another minimal process (B) with a Var and an InPort, OutPort and VarPort class ProcB(AbstractProcess): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -39,9 +42,10 @@ def __init__(self, **kwargs): self.inp = InPort(shape=(1,), reduce_op=ReduceSum) self.out = OutPort(shape=(1,)) self.some_var = Var((10,), init=10) + self.var_port = VarPort(self.some_var) -# Another minimal process (does not matter that it's identical to ProcA) +# Another minimal process (C) with an InPort and OutPort class ProcC(AbstractProcess): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -74,6 +78,7 @@ def runtime_service(self): class PyProcModelA(AbstractPyProcessModel): inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int) out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int) + ref: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int) def run(self): pass @@ -86,6 +91,7 @@ class PyProcModelB(AbstractPyProcessModel): inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int) out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int) some_var: int = LavaPyType(int, int) + var_port: PyVarPort = LavaPyType(PyVarPort.VEC_DENSE, int) def run(self): pass @@ -262,6 +268,30 @@ def test_find_process_circular(self): self.assertEqual(set(procs4), all_procs) self.assertEqual(set(procs6), all_procs) + def test_find_process_ref_ports(self): + """Checks finding all processes for RefPort connection. + [p1 -> ref/var -> p2 -> out/in -> p3]""" + + # Create processes + p1, p2, p3 = ProcA(), ProcB(), ProcC() + + # Connect p1 (RefPort) with p2 (VarPort) + p1.ref.connect(p2.var_port) + # Connect p2 (OutPort) with p3 (InPort) + p2.out.connect(p3.inp) + + # Regardless where we start searching... + c = Compiler() + procs1 = c._find_processes(p1) + procs2 = c._find_processes(p2) + procs3 = c._find_processes(p3) + + # ...we will find all of them + all_procs = {p1, p2, p3} + self.assertEqual(set(procs1), all_procs) + self.assertEqual(set(procs2), all_procs) + self.assertEqual(set(procs3), all_procs) + def test_find_proc_models(self): """Check finding of ProcModels that implement a Process.""" @@ -658,6 +688,66 @@ def test_create_channel_builders_hierarchical_process(self): self.assertEqual(chb[0].src_process, p.procs.proc1) self.assertEqual(chb[0].dst_process, p.procs.proc2) + def test_create_channel_builders_ref_ports(self): + """Checks creation of channel builders when a process is connected + using a RefPort to another process VarPort.""" + + # Create a process with a RefPort (source) + src = ProcA() + + # Create a process with a var (destination) + dst = ProcB() + + # Connect them using RefPort and VarPort + src.ref.connect(dst.var_port) + + # Create a manual proc_map + proc_map = { + src: PyProcModelA, + dst: PyProcModelB + } + + # Create channel builders + c = Compiler() + cbs = c._create_channel_builders(proc_map) + + # This should result in 2 channel builder + from lava.magma.compiler.builder import ChannelBuilderMp + self.assertEqual(len(cbs), 2) + self.assertIsInstance(cbs[0], ChannelBuilderMp) + self.assertEqual(cbs[0].src_process, src) + self.assertEqual(cbs[0].dst_process, dst) + + def test_create_channel_builders_ref_ports_implicit(self): + """Checks creation of channel builders when a process is connected + using a RefPort to another process Var (implicit VarPort).""" + + # Create a process with a RefPort (source) + src = ProcA() + + # Create a process with a var (destination) + dst = ProcB() + + # Connect them using RefPort and Var (creates implicitly a VarPort) + src.ref.connect_var(dst.some_var) + + # Create a manual proc_map + proc_map = { + src: PyProcModelA, + dst: PyProcModelB + } + + # Create channel builders + c = Compiler() + cbs = c._create_channel_builders(proc_map) + + # This should result in 2 channel builder + from lava.magma.compiler.builder import ChannelBuilderMp + self.assertEqual(len(cbs), 2) + self.assertIsInstance(cbs[0], ChannelBuilderMp) + self.assertEqual(cbs[0].src_process, src) + self.assertEqual(cbs[0].dst_process, dst) + # ToDo: (AW) @YS/@JM Please fix unit test by passing run_srv_builders to # _create_exec_vars when ready def test_create_py_exec_vars(self): diff --git a/tests/lava/magma/core/model/test_py_model.py b/tests/lava/magma/core/model/test_py_model.py index 6f9b4b083..4f9b7a9a4 100644 --- a/tests/lava/magma/core/model/test_py_model.py +++ b/tests/lava/magma/core/model/test_py_model.py @@ -8,14 +8,17 @@ from lava.magma.core.process.process import AbstractProcess from lava.magma.core.process.variable import Var -from lava.magma.core.process.ports.ports import InPort, OutPort +from lava.magma.core.process.ports.ports import InPort, OutPort, RefPort, \ + VarPort from lava.magma.core.decorator import implements, requires from lava.magma.core.resources import CPU from lava.magma.core.model.py.model import AbstractPyProcessModel from lava.magma.core.model.py.type import LavaPyType -from lava.magma.core.model.py.ports import PyInPort, PyOutPort +from lava.magma.core.model.py.ports import PyInPort, PyOutPort, PyRefPort, \ + PyVarPort -from lava.magma.compiler.utils import VarInitializer, PortInitializer +from lava.magma.compiler.utils import VarInitializer, PortInitializer, \ + VarPortInitializer from lava.magma.compiler.builder import PyProcessBuilder from lava.magma.compiler.channels.interfaces import AbstractCspPort @@ -104,7 +107,7 @@ class ProcModelForLavaPyType1(AbstractPyProcessModel): port: PyInPort = LavaPyType(123, int) # type: ignore -# A wrong ProcessModel with wrong syb type +# A wrong ProcessModel with wrong sub type @implements(proc=ProcForLavaPyType) @requires(CPU) class ProcModelForLavaPyType2(AbstractPyProcessModel): @@ -118,6 +121,24 @@ class ProcModelForLavaPyType3(AbstractPyProcessModel): port: PyInPort = LavaPyType(PyOutPort, int) +# A minimal process to test RefPorts and VarPorts +class ProcRefVar(AbstractProcess): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.ref = RefPort(shape=(3,)) + self.var = Var(shape=(3,), init=4) + self.var_port = VarPort(self.var) + + +# A minimal PyProcModel implementing ProcRefVar +@implements(proc=ProcRefVar) +@requires(CPU) +class PyProcModelRefVar(AbstractPyProcessModel): + ref: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int) + var: np.ndarray = LavaPyType(np.ndarray, np.int32) + var_port: PyVarPort = LavaPyType(PyVarPort.VEC_DENSE, int) + + class TestPyProcessBuilder(unittest.TestCase): """ProcessModels are not not created directly but through a corresponding PyProcessBuilder. Therefore, we test both classes together.""" @@ -234,7 +255,7 @@ def test_check_lava_py_types(self): InPort called 'port' """ - # Create univeral PortInitializer reflecting the 'port' in + # Create universal PortInitializer reflecting the 'port' in # ProcForLavaPyType pi = PortInitializer("port", (1,), np.intc, "InPort", 32) @@ -411,14 +432,73 @@ def test_build_with_dangling_ports(self): # Validate that the Process with no OutPorts indeed has no output # CspPort self.assertIsInstance( - pm_with_no_out_ports.in_port._csp_ports[0], FakeCspPort) - self.assertEqual(pm_with_no_out_ports.out_port._csp_ports, []) + pm_with_no_out_ports.in_port._csp_recv_ports[0], FakeCspPort) + self.assertEqual(pm_with_no_out_ports.out_port._csp_send_ports, []) # Validate that the Process with no InPorts indeed has no input # CspPort - self.assertEqual(pm_with_no_in_ports.in_port._csp_ports, []) + self.assertEqual(pm_with_no_in_ports.in_port._csp_recv_ports, []) self.assertIsInstance( - pm_with_no_in_ports.out_port._csp_ports[0], FakeCspPort) + pm_with_no_in_ports.out_port._csp_send_ports[0], FakeCspPort) + + def test_set_ref_var_ports(self): + """Check RefPorts and VarPorts can be set.""" + + # Create a new ProcBuilder + b = PyProcessBuilder(PyProcModelRefVar, 0) + + # Create Process for which we want to build PyProcModel + proc = ProcRefVar() + + # Normally, the Compiler would create PortInitializers from all + # ref ports holding only its name and shape + ports = list(proc.ref_ports) + ref_ports = [PortInitializer( + pt.name, + pt.shape, + getattr(PyProcModelRefVar, pt.name).d_type, + pt.__class__.__name__, 32) + for pt in ports] + # Similarly, the Compiler would create VarPortInitializers from all + # var ports holding only its name, shape and var_name + ports = list(proc.var_ports) + var_ports = [VarPortInitializer( + pt.name, + pt.shape, + pt.var.name, + getattr(PyProcModelRefVar, pt.name).d_type, + pt.__class__.__name__, 32, PyRefPort.VEC_DENSE) + for pt in ports] + # The Runtime, would normally create CspPorts that implement the actual + # message passing via channels between RefPorts and VarPorts. Here we + # just create some fake CspPorts for each Ref- and VarPort. + # 2 CspChannels per Ref-/VarPort. + csp_ports = [] + for port in list(ref_ports): + csp_ports.append(FakeCspPort(port.name)) + csp_ports.append(FakeCspPort(port.name)) + for port in list(var_ports): + csp_ports.append(FakeCspPort(port.name)) + csp_ports.append(FakeCspPort(port.name)) + + # During compilation, the Compiler creates and then sets + # PortInitializers and VarPortInitializers + b.set_ref_ports(ref_ports) + b.set_var_ports(var_ports) + # The Runtime sets CspPorts + b.set_csp_ports(csp_ports) + + # All the objects are converted into dictionaries to retrieve them by + # name + self.assertEqual(list(b.py_ports.values()), []) + self.assertEqual(list(b.ref_ports.values()), ref_ports) + self.assertEqual(list(b.var_ports.values()), var_ports) + self.assertEqual(list(v for vv in b.csp_ports.values() + for v in vv), csp_ports) + self.assertEqual(b.ref_ports["ref"], ref_ports[0]) + self.assertEqual(b.csp_ports["ref"], [csp_ports[0], csp_ports[1]]) + self.assertEqual(b.var_ports["var_port"], var_ports[0]) + self.assertEqual(b.csp_ports["var_port"], [csp_ports[2], csp_ports[3]]) if __name__ == "__main__": diff --git a/tests/lava/magma/core/process/test_ports.py b/tests/lava/magma/core/process/test_ports.py index a395b22a8..eed505f4a 100644 --- a/tests/lava/magma/core/process/test_ports.py +++ b/tests/lava/magma/core/process/test_ports.py @@ -1,7 +1,8 @@ # Copyright (C) 2021 Intel Corporation -# SPDX-License-Identifier: BSD-3-Clause +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ import unittest - +from lava.magma.core.process.process import AbstractProcess from lava.magma.core.process.ports.ports import ( InPort, OutPort, @@ -245,9 +246,64 @@ def test_connect_RefPort_to_Var(self): # In this case, the VarPort inherits its name and parent process from # the Var it wraps - self.assertEqual(vp.name, v.name + "_port") + self.assertEqual(vp.name, "_" + v.name + "_implicit_port") # (We can't check for the same parent process here because it has not - # been assigned ot the Var yet) + # been assigned to the Var yet) + + def test_connect_RefPort_to_Var_process(self): + """Checks connecting RefPort implicitly to Var, with registered + processes.""" + + # Create a mock parent process + class VarProcess(AbstractProcess): + ... + + # Create a Var and RefPort... + v = Var((1, 2, 3)) + rp = RefPort((1, 2, 3)) + + # ...register a process for the Var + v.process = VarProcess() + + # ...then connect them directly via connect_var(..) + rp.connect_var(v) + + # This has the same effect as connecting a RefPort explicitly via a + # VarPort to a Var... + self.assertEqual(rp.get_dst_vars(), [v]) + # ... but still creates a VarPort implicitly + vp = rp.get_dst_ports()[0] + self.assertIsInstance(vp, VarPort) + # ... which wraps the original Var + self.assertEqual(vp.var, v) + + # In this case, the VarPort inherits its name and parent process from + # the Var it wraps + self.assertEqual(vp.name, "_" + v.name + "_implicit_port") + self.assertEqual(vp.process, v.process) + + def test_connect_RefPort_to_Var_process_conflict(self): + """Checks connecting RefPort implicitly to Var, with registered + processes and conflicting names. -> AssertionError""" + + # Create a mock parent process + class VarProcess(AbstractProcess): + # Attribute is named like our implicit VarPort after creation + _existing_attr_implicit_port = None + + # Create a Var and RefPort... + v = Var((1, 2, 3)) + rp = RefPort((1, 2, 3)) + + # ...register a process for the Var and name it so it conflicts with + # the attribute of VarProcess (very unlikely to happen) + v.process = VarProcess() + v.name = "existing_attr" + + # ... and connect it directly via connect_var(..) + # The naming conflict should raise an AssertionError + with self.assertRaises(AssertionError): + rp.connect_var(v) def test_connect_RefPort_to_many_Vars(self): """Checks that RefPort can be connected to many Vars.""" @@ -291,6 +347,64 @@ def test_connect_RefPort_to_non_sharable_Var(self): with self.assertRaises(VarNotSharableError): rp.connect_var(v) + def test_connect_RefPort_to_InPort_OutPort(self): + """Checks connecting RefPort to an InPort or OutPort. -> TypeError""" + + # Create an InPort, OutPort, RefPort... + ip = InPort((1, 2, 3)) + op = OutPort((1, 2, 3)) + rp = RefPort((1, 2, 3)) + + # ... and connect them via connect(..) + # The type conflict should raise an TypeError + with self.assertRaises(TypeError): + rp.connect(ip) + + with self.assertRaises(TypeError): + rp.connect(op) + + # Connect them via connect_from(..) + # The type conflict should raise an TypeError + with self.assertRaises(TypeError): + rp.connect_from(ip) + + with self.assertRaises(TypeError): + rp.connect_from(op) + + def test_connect_VarPort_to_InPort_OutPort_RefPort(self): + """Checks connecting VarPort to an InPort, OutPort or RefPort. + -> TypeError (RefPort can only be connected via connect_from(..) to + VarPort.""" + + # Create an InPort, OutPort, RefPort, Var with VarPort... + ip = InPort((1, 2, 3)) + op = OutPort((1, 2, 3)) + rp = RefPort((1, 2, 3)) + v = Var((1, 2, 3)) + vp = VarPort(v) + + # ... and connect them via connect(..) + # The type conflict should raise an TypeError + with self.assertRaises(TypeError): + vp.connect(ip) + + with self.assertRaises(TypeError): + vp.connect(op) + + with self.assertRaises(TypeError): + vp.connect(rp) + + # Connect them via connect_from(..) + # The type conflict should raise an TypeError + with self.assertRaises(TypeError): + vp.connect_from(ip) + + with self.assertRaises(TypeError): + vp.connect_from(op) + + # Connect RefPort via connect_from(..) raises no error + vp.connect_from(rp) + class TestVirtualPorts(unittest.TestCase): """Contains unit tests around virtual ports. Virtual ports are derived diff --git a/tests/lava/magma/runtime/test_get_set_var.py b/tests/lava/magma/runtime/test_get_set_var.py index 02f33ee69..ba11929a4 100644 --- a/tests/lava/magma/runtime/test_get_set_var.py +++ b/tests/lava/magma/runtime/test_get_set_var.py @@ -1,3 +1,7 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + import numpy as np import unittest diff --git a/tests/lava/magma/runtime/test_ref_var_ports.py b/tests/lava/magma/runtime/test_ref_var_ports.py new file mode 100644 index 000000000..fbd62c576 --- /dev/null +++ b/tests/lava/magma/runtime/test_ref_var_ports.py @@ -0,0 +1,225 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import numpy as np +import unittest + +from lava.magma.core.decorator import implements, requires +from lava.magma.core.model.py.model import PyLoihiProcessModel +from lava.magma.core.model.py.ports import PyRefPort, PyVarPort +from lava.magma.core.model.py.type import LavaPyType +from lava.magma.core.process.ports.ports import RefPort, VarPort +from lava.magma.core.process.process import AbstractProcess +from lava.magma.core.process.variable import Var +from lava.magma.core.resources import CPU +from lava.magma.core.sync.domain import SyncDomain +from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol +from lava.magma.core.run_configs import RunConfig +from lava.magma.core.run_conditions import RunSteps + + +# A minimal process with a Var and a RefPort, VarPort +class P1(AbstractProcess): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.ref1 = RefPort(shape=(3,)) + self.var1 = Var(shape=(2,), init=17) + self.var_port_var1 = VarPort(self.var1) + + +# A minimal process with 2 Vars and a RefPort, VarPort +class P2(AbstractProcess): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.var2 = Var(shape=(3,), init=4) + self.var_port_var2 = VarPort(self.var2) + self.ref2 = RefPort(shape=(2,)) + self.var3 = Var(shape=(2,), init=1) + + +# A minimal PyProcModel implementing P1 +@implements(proc=P1, protocol=LoihiProtocol) +@requires(CPU) +class PyProcModel1(PyLoihiProcessModel): + ref1: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int) + var1: np.ndarray = LavaPyType(np.ndarray, np.int32) + var_port_var1: PyVarPort = LavaPyType(PyVarPort.VEC_DENSE, int) + + def pre_guard(self): + return True + + def run_pre_mgmt(self): + if self.current_ts > 1: + ref_data = np.array([5, 5, 5]) + self.current_ts + self.ref1.write(ref_data) + + +# A minimal PyProcModel implementing P2 +@implements(proc=P2, protocol=LoihiProtocol) +@requires(CPU) +class PyProcModel2(PyLoihiProcessModel): + ref2: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int) + var2: np.ndarray = LavaPyType(np.ndarray, np.int32) + var_port_var2: PyVarPort = LavaPyType(PyVarPort.VEC_DENSE, int) + var3: np.ndarray = LavaPyType(np.ndarray, np.int32) + + def pre_guard(self): + return True + + def run_pre_mgmt(self): + if self.current_ts > 1: + self.var3 = self.ref2.read() + + +# A simple RunConfig selecting always the first found process model +class MyRunCfg(RunConfig): + def select(self, proc, proc_models): + return proc_models[0] + + +class TestRefVarPorts(unittest.TestCase): + def test_unconnected_Ref_Var_ports(self): + """RefPorts and VarPorts defined in ProcessModels, but not connected + should not lead to an error.""" + sender = P1() + + # No connections are made + + simple_sync_domain = SyncDomain("simple", LoihiProtocol(), + [sender]) + + # The process should compile and run without error (not doing anything) + sender.run(RunSteps(num_steps=3, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + sender.stop() + + def test_explicit_Ref_Var_port_write(self): + """Tests the connection of a RefPort to an explicitly created VarPort. + The RefPort sends data after the first time step to the VarPort, + starting with (5 + current time step) = 7). The initial value of the + var is 4. We read out the value after each time step.""" + + sender = P1() + recv = P2() + + # Connect RefPort with explicit VarPort + sender.ref1.connect(recv.var_port_var2) + + simple_sync_domain = SyncDomain("simple", LoihiProtocol(), + [sender, recv]) + + # First time step, no data is sent + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + # Initial value is expected + self.assertTrue(np.all(recv.var2.get() == np.array([4., 4., 4.]))) + # Second time step, data is sent (7) + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + self.assertTrue(np.all(recv.var2.get() == np.array([7., 7., 7.]))) + # Third time step, data is sent (8) + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + self.assertTrue(np.all(recv.var2.get() == np.array([8., 8., 8.]))) + # Fourth time step, data is sent (9) + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + self.assertTrue(np.all(recv.var2.get() == np.array([9., 9., 9.]))) + sender.stop() + + def test_implicit_Ref_Var_port_write(self): + """Tests the connection of a RefPort to an implicitly created VarPort. + The RefPort sends data after the first time step to the VarPort, + starting with (5 + current time step) = 7). The initial value of the + var is 4. We read out the value after each time step.""" + + sender = P1() + recv = P2() + + # Connect RefPort with Var using an implicit VarPort + sender.ref1.connect_var(recv.var2) + + simple_sync_domain = SyncDomain("simple", LoihiProtocol(), + [sender, recv]) + + # First time step, no data is sent + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + # Initial value is expected + self.assertTrue(np.all(recv.var2.get() == np.array([4., 4., 4.]))) + # Second time step, data is sent (7) + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + self.assertTrue(np.all(recv.var2.get() == np.array([7., 7., 7.]))) + # Third time step, data is sent (8) + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + self.assertTrue(np.all(recv.var2.get() == np.array([8., 8., 8.]))) + # Fourth time step, data is sent (9) + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + self.assertTrue(np.all(recv.var2.get() == np.array([9., 9., 9.]))) + sender.stop() + + def test_explicit_Ref_Var_port_read(self): + """Tests the connection of a RefPort to an explicitly created VarPort. + The RefPort "ref_read" reads data after the first time step of the + VarPort "var_port_read" which has the value of the Var "v" (= 17) and + writes this value into the Var "var_read". The initial value of the var + "var_read" is 1. At time step 2 the value of "var_read" is 17.""" + + sender = P1() + recv = P2() + + # Connect RefPort with explicit VarPort + recv.ref2.connect(sender.var_port_var1) + + simple_sync_domain = SyncDomain("simple", LoihiProtocol(), + [sender, recv]) + + # First time step, no read + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + # Initial value (1) is expected + self.assertTrue(np.all(recv.var3.get() == np.array([1., 1.]))) + # Second time step, the RefPort read from the VarPort and wrote the + # Result in "var_read" (= 17) + sender.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + self.assertTrue( + np.all(recv.var3.get() == np.array([17., 17.]))) + sender.stop() + + def test_implicit_Ref_Var_port_read(self): + """Tests the connection of a RefPort to an implicitly created VarPort. + The RefPort "ref_read" reads data after the first time step of the + of the Var "v" (= 17) using an implicit VarPort and writes this value + into the Var "var_read". The initial value of the var "var_read" is 1. + At time step 2 the value of "var_read" is 17.""" + + sender = P1() + recv = P2() + + # Connect RefPort with explicit VarPort + recv.ref2.connect_var(sender.var1) + + simple_sync_domain = SyncDomain("simple", LoihiProtocol(), + [sender, recv]) + + # First time step, no read + recv.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + # Initial value (1) is expected + self.assertTrue(np.all(recv.var3.get() == np.array([1., 1.]))) + # Second time step, the RefPort read from the VarPort and wrote the + # Result in "var_read" (= 17) + recv.run(RunSteps(num_steps=1, blocking=True), + MyRunCfg(custom_sync_domains=[simple_sync_domain])) + self.assertTrue( + np.all(recv.var3.get() == np.array([17., 17.]))) + recv.stop() + + +if __name__ == '__main__': + unittest.main()