diff --git a/src/lava/magma/compiler/builders/builder.py b/src/lava/magma/compiler/builders/builder.py index 22d02bf6b..c9aa45fc7 100644 --- a/src/lava/magma/compiler/builders/builder.py +++ b/src/lava/magma/compiler/builders/builder.py @@ -370,7 +370,17 @@ def build(self): csp_ports = self.csp_ports[name] if not isinstance(csp_ports, list): csp_ports = [csp_ports] - port = port_cls(csp_ports, pm, p.shape, lt.d_type) + + # TODO (MR): This is probably just a temporary hack until the + # interface of PyOutPorts has been adjusted. + if issubclass(port_cls, PyInPort): + port = port_cls(csp_ports, pm, p.shape, lt.d_type, + p.transform_funcs) + elif issubclass(port_cls, PyOutPort): + port = port_cls(csp_ports, pm, p.shape, lt.d_type) + else: + raise AssertionError("port_cls must be of type PyInPort or " + "PyOutPort") # Create dynamic PyPort attribute on ProcModel setattr(pm, name, port) diff --git a/src/lava/magma/compiler/compiler.py b/src/lava/magma/compiler/compiler.py index da1eb22d4..14a513e8e 100644 --- a/src/lava/magma/compiler/compiler.py +++ b/src/lava/magma/compiler/compiler.py @@ -1,6 +1,7 @@ # Copyright (C) 2021 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause # See: https://spdx.org/licenses/ + import logging import importlib import importlib.util as import_utils @@ -38,7 +39,7 @@ from lava.magma.core.model.py.ports import RefVarTypeMapping from lava.magma.core.model.sub.model import AbstractSubProcessModel from lava.magma.core.process.ports.ports import AbstractPort, VarPort, \ - ImplicitVarPort, RefPort + ImplicitVarPort, InPort, RefPort from lava.magma.core.process.process import AbstractProcess from lava.magma.core.resources import ( CPU, @@ -346,12 +347,25 @@ def _compile_proc_models( # and Ports v = [VarInitializer(v.name, v.shape, v.init, v.id) for v in p.vars] - ports = (list(p.in_ports) + list(p.out_ports)) - ports = [PortInitializer(pt.name, + + ports = [] + for pt in (list(p.in_ports) + list(p.out_ports)): + # For all InPorts that receive input from + # virtual ports... + transform_funcs = None + if isinstance(pt, InPort): + # ... extract a function pointer to the + # transformation function of each virtual port. + transform_funcs = \ + [vp.get_transform_func() + for vp in pt.get_incoming_virtual_ports()] + pi = PortInitializer(pt.name, pt.shape, self._get_port_dtype(pt, pm), pt.__class__.__name__, - pp_ch_size) for pt in ports] + pp_ch_size, + transform_funcs) + ports.append(pi) # Create RefPort (also use PortInitializers) ref_ports = list(p.ref_ports) ref_ports = [ diff --git a/src/lava/magma/compiler/utils.py b/src/lava/magma/compiler/utils.py index 14da90baf..27f4775fb 100644 --- a/src/lava/magma/compiler/utils.py +++ b/src/lava/magma/compiler/utils.py @@ -1,4 +1,5 @@ import typing as ty +import functools as ft from dataclasses import dataclass @@ -17,6 +18,7 @@ class PortInitializer: d_type: type port_type: str size: int + transform_funcs: ty.List[ft.partial] = None # check if can be a subclass of PortInitializer diff --git a/src/lava/magma/core/model/py/ports.py b/src/lava/magma/core/model/py/ports.py index faf0e5103..6a3c6ffa2 100644 --- a/src/lava/magma/core/model/py/ports.py +++ b/src/lava/magma/core/model/py/ports.py @@ -1,6 +1,7 @@ # Copyright (C) 2021 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause # See: https://spdx.org/licenses/ + import typing as ty from abc import abstractmethod import functools as ft @@ -137,6 +138,16 @@ class PyInPort(AbstractPyIOPort): SCALAR_DENSE: ty.Type["PyInPortScalarDense"] = None SCALAR_SPARSE: ty.Type["PyInPortScalarSparse"] = None + def __init__(self, + csp_ports: ty.List[AbstractCspPort], + process_model: AbstractProcessModel, + shape: ty.Tuple[int, ...], + d_type: type, + transform_funcs: ty.Optional[ty.List[ft.partial]] = None): + + self._transform_funcs = transform_funcs + super().__init__(csp_ports, process_model, shape, d_type) + @abstractmethod def recv(self): """Abstract method to receive data (vectors/scalars) sent from connected @@ -182,6 +193,25 @@ def probe(self) -> bool: True, ) + def _transform(self, recv_data: np.array) -> np.array: + """Applies all transformation function pointers to the input data. + + Parameters + ---------- + recv_data : numpy.ndarray + data received on the port that shall be transformed + + Returns + ------- + recv_data : numpy.ndarray + received data, transformed by the incoming virtual ports + """ + if self._transform_funcs: + # apply all transformation functions to the received data + for f in self._transform_funcs: + recv_data = f(recv_data) + return recv_data + class PyInPortVectorDense(PyInPort): """Python implementation of PyInPort for dense vector data.""" @@ -199,7 +229,7 @@ def recv(self) -> np.ndarray: fashion. """ return ft.reduce( - lambda acc, csp_port: acc + csp_port.recv(), + lambda acc, csp_port: acc + self._transform(csp_port.recv()), self.csp_ports, np.zeros(self._shape, self._d_type), ) diff --git a/src/lava/magma/core/process/ports/exceptions.py b/src/lava/magma/core/process/ports/exceptions.py index 1c1a96de0..ed65ad99d 100644 --- a/src/lava/magma/core/process/ports/exceptions.py +++ b/src/lava/magma/core/process/ports/exceptions.py @@ -34,6 +34,48 @@ def __init__(self, shapes, axis): super().__init__(self, msg) +class ConcatIndexError(Exception): + """Raised when the axis over which ports should be concatenated is out of + bounds.""" + + def __init__(self, shape: ty.Tuple[int], axis: int): + msg = ( + "Axis {} is out of bounds for given shape {}.".format(axis, shape) + ) + super().__init__(self, msg) + + +class TransposeShapeError(Exception): + """Raised when transpose axes is incompatible with old shape dimension.""" + + def __init__( + self, old_shape: ty.Tuple, axes: ty.Union[ty.Tuple, ty.List] + ) -> None: + msg = ( + "Cannot transpose 'old_shape'={} with permutation 'axes={}. " + "Total number of dimensions must not change during " + "reshaping.".format(old_shape, axes) + ) + super().__init__(msg) + + +class TransposeIndexError(Exception): + """Raised when indices in transpose axes are out of bounds for the old + shape dimension.""" + + def __init__( + self, + old_shape: ty.Tuple, + axes: ty.Union[ty.Tuple, ty.List], + wrong_index + ) -> None: + msg = ( + f"Cannot transpose 'old_shape'={old_shape} with permutation" + f"'axes'={axes}. The index {wrong_index} is out of bounds." + ) + super().__init__(msg) + + class VarNotSharableError(Exception): """Raised when an attempt is made to connect a RefPort or VarPort to a non-sharable Var.""" diff --git a/src/lava/magma/core/process/ports/ports.py b/src/lava/magma/core/process/ports/ports.py index 4d4a651a5..dfa82c301 100644 --- a/src/lava/magma/core/process/ports/ports.py +++ b/src/lava/magma/core/process/ports/ports.py @@ -1,8 +1,12 @@ # Copyright (C) 2021 Intel Corporation -# SPDX-License-Identifier: BSD-3-Clause +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + import typing as ty from abc import ABC, abstractmethod import math +import numpy as np +import functools as ft from lava.magma.core.process.interfaces import AbstractProcessMember import lava.magma.core.process.ports.exceptions as pe @@ -141,6 +145,39 @@ def get_src_ports(self, _include_self=False) -> ty.List["AbstractPort"]: ports += p.get_src_ports(True) return ports + def get_incoming_virtual_ports(self) -> ty.List["AbstractVirtualPort"]: + """Returns the list of all incoming virtual ports in order from + source to the current port. + + Returns + ------- + virtual_ports : list(AbstractVirtualPorts) + the list of all incoming virtual ports, sorted from source to + destination port + """ + if len(self.in_connections) == 0: + return [] + else: + virtual_ports = [] + num_virtual_ports = 0 + for p in self.in_connections: + virtual_ports += p.get_incoming_virtual_ports() + if isinstance(p, AbstractVirtualPort): + # TODO (MR): ConcatPorts are not yet supported by the + # compiler - until then, an exception is raised. + if isinstance(p, ConcatPort): + raise NotImplementedError("ConcatPorts are not yet " + "supported.") + + virtual_ports.append(p) + num_virtual_ports += 1 + + if num_virtual_ports > 1: + raise NotImplementedError("Joining multiple virtual ports is " + "not yet supported.") + + return virtual_ports + def get_dst_ports(self, _include_self=False) -> ty.List["AbstractPort"]: """Returns the list of all destination ports that this port connects to either directly or indirectly (through other ports).""" @@ -165,6 +202,12 @@ def reshape(self, new_shape: ty.Tuple) -> "ReshapePort": :param new_shape: New shape of port. Number of total elements must not change. """ + # TODO (MR): Implement for other types of Ports + if not (isinstance(self, OutPort) + or isinstance(self, AbstractVirtualPort)): + raise NotImplementedError("reshape/flatten are only implemented " + "for OutPorts") + if self.size != math.prod(new_shape): raise pe.ReshapeError(self.shape, new_shape) @@ -204,6 +247,48 @@ def concat_with( self._validate_ports(ports, port_type, assert_same_shape=False) return ConcatPort(ports, axis) + def transpose( + self, + axes: ty.Optional[ty.Union[ty.Tuple, ty.List]] = None + ) -> "TransposePort": + """Permutes the tensor dimension of this port by deriving and returning + a new virtual TransposePort the new permuted dimension. This implies + that the resulting TransposePort can only be forward connected to + another port. + + Parameters + ---------- + :param axes: Order of permutation. Number of total elements and number + of dimensions must not change. + """ + # TODO (MR): Implement for other types of Ports + if not (isinstance(self, OutPort) + or isinstance(self, AbstractVirtualPort)): + raise NotImplementedError("transpose is only implemented for " + "OutPorts") + + if axes is None: + axes = tuple(reversed(range(len(self.shape)))) + else: + if len(self.shape) != len(axes): + raise pe.TransposeShapeError(self.shape, axes) + + # Check that none of the given axes are out of bounds for the + # shape of the parent port. + for idx in axes: + # Compute the positive index irrespective of the sign of 'idx' + idx_positive = len(self.shape) + idx if idx < 0 else idx + # Make sure the positive index is not out of bounds + if idx_positive < 0 or idx_positive >= len(self.shape): + raise pe.TransposeIndexError(self.shape, axes, idx) + + new_shape = tuple([self.shape[i] for i in axes]) + transpose_port = TransposePort(new_shape, axes) + self._connect_forward( + [transpose_port], AbstractPort, assert_same_shape=False + ) + return transpose_port + class AbstractIOPort(AbstractPort): """Abstract base class for InPorts and OutPorts. @@ -566,17 +651,14 @@ class ImplicitVarPort(VarPort): pass -class AbstractVirtualPort(ABC): +class AbstractVirtualPort(AbstractPort): """Abstract base class interface for any type of port that merely serves - to transforms the properties of a user-defined port. - Needs no implementation because this class purely serves as a - type-identifier.""" + to transform the properties of a user-defined port.""" @property - @abstractmethod def _parent_port(self): """Must return parent port that this VirtualPort was derived from.""" - pass + return self.get_src_ports()[0] @property def process(self): @@ -584,24 +666,8 @@ def process(self): derived from.""" return self._parent_port.process - -# ToDo: (AW) ReshapePort.connect(..) could be consolidated with -# ConcatPort.connect(..) -class ReshapePort(AbstractPort, AbstractVirtualPort): - """A ReshapePort is a virtual port that allows to change the shape of a - port before connecting to another port. - It is used by the compiler to map the indices of the underlying - tensor-valued data array from the derived to the new shape.""" - - def __init__(self, shape: ty.Tuple): - AbstractPort.__init__(self, shape) - - @property - def _parent_port(self) -> AbstractPort: - return self.in_connections[0] - def connect(self, ports: ty.Union["AbstractPort", ty.List["AbstractPort"]]): - """Connects this ReshapePort to other port(s). + """Connects this virtual port to other port(s). Parameters ---------- @@ -626,8 +692,25 @@ def connect(self, ports: ty.Union["AbstractPort", ty.List["AbstractPort"]]): # Connect to ports self._connect_forward(to_list(ports), port_type) + @abstractmethod + def get_transform_func(self) -> ft.partial: + pass + + +class ReshapePort(AbstractVirtualPort): + """A ReshapePort is a virtual port that allows to change the shape of a + port before connecting to another port. + It is used by the compiler to map the indices of the underlying + tensor-valued data array from the derived to the new shape.""" + + def __init__(self, new_shape: ty.Tuple): + AbstractPort.__init__(self, new_shape) + + def get_transform_func(self) -> ft.partial: + return ft.partial(np.reshape, newshape=self.shape) -class ConcatPort(AbstractPort, AbstractVirtualPort): + +class ConcatPort(AbstractVirtualPort): """A ConcatPort is a virtual port that allows to concatenate multiple ports along given axis into a new port before connecting to another port. The shape of all concatenated ports outside of the concatenation @@ -651,6 +734,9 @@ def _get_new_shape(ports: ty.List[AbstractPort], axis): shapes_ex_axis = [] shapes_incompatible = False for shape in concat_shapes: + if axis >= len(shape): + raise pe.ConcatIndexError(shape, axis) + # Compute total size along concatenation axis total_size += shape[axis] # Extract shape dimensions other than concatenation axis @@ -665,40 +751,13 @@ def _get_new_shape(ports: ty.List[AbstractPort], axis): new_shape = shapes_ex_axis[0] return new_shape[:axis] + (total_size,) + new_shape[axis:] - @property - def _parent_port(self) -> AbstractPort: - return self.in_connections[0] - - def connect(self, ports: ty.Union["AbstractPort", ty.List["AbstractPort"]]): - """Connects this ConcatPort to other port(s) - - Parameters - ---------- - :param ports: The port(s) to connect to. Connections from an IOPort - to a RVPort and vice versa are not allowed. - """ - # Determine allows port_type - if isinstance(self._parent_port, OutPort): - # If OutPort, only allow other IO ports - port_type = AbstractIOPort - elif isinstance(self._parent_port, InPort): - # If InPort, only allow other InPorts - port_type = InPort - elif isinstance(self._parent_port, RefPort): - # If RefPort, only allow other Ref- or VarPorts - port_type = AbstractRVPort - elif isinstance(self._parent_port, VarPort): - # If VarPort, only allow other VarPorts - port_type = VarPort - else: - raise TypeError("Illegal parent port.") - # Connect to ports - self._connect_forward(to_list(ports), port_type) + def get_transform_func(self) -> ft.partial: + # TODO (MR): not yet implemented + raise NotImplementedError() -# ToDo: TBD... -class PermutePort(AbstractPort, AbstractVirtualPort): - """A PermutePort is a virtual port that allows to permute the dimensions +class TransposePort(AbstractVirtualPort): + """A TransposePort is a virtual port that allows to permute the dimensions of a port before connecting to another port. It is used by the compiler to map the indices of the underlying tensor-valued data array from the derived to the new shape. @@ -706,14 +765,21 @@ class PermutePort(AbstractPort, AbstractVirtualPort): Example: out_port = OutPort((2, 4, 3)) in_port = InPort((3, 2, 4)) - out_port.permute([3, 1, 2]).connect(in_port) + out_port.transpose([3, 1, 2]).connect(in_port) """ - pass + def __init__(self, + new_shape: ty.Tuple[int, ...], + axes: ty.Tuple[int, ...]): + self._axes = axes + AbstractPort.__init__(self, new_shape) + + def get_transform_func(self) -> ft.partial: + return ft.partial(np.transpose, axes=self._axes) # ToDo: TBD... -class ReIndexPort(AbstractPort, AbstractVirtualPort): +class ReIndexPort(AbstractVirtualPort): """A ReIndexPort is a virtual port that allows to re-index the elements of a port before connecting to another port. It is used by the compiler to map the indices of the underlying diff --git a/tests/lava/magma/core/model/py/test_ports.py b/tests/lava/magma/core/model/py/test_ports.py index e34c799db..1a59e7dd4 100644 --- a/tests/lava/magma/core/model/py/test_ports.py +++ b/tests/lava/magma/core/model/py/test_ports.py @@ -61,7 +61,7 @@ def probe_test_routine(self, cls): # Create PyInPort with current implementation recv_py_port: PyInPort = \ cls([recv_csp_port_1, recv_csp_port_2], None, data.shape, - data.dtype) + data.dtype, None) recv_py_port.start() send_py_port_1.start() diff --git a/tests/lava/magma/core/process/ports/__init__.py b/tests/lava/magma/core/process/ports/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/lava/magma/core/process/test_ports.py b/tests/lava/magma/core/process/ports/test_ports.py similarity index 82% rename from tests/lava/magma/core/process/test_ports.py rename to tests/lava/magma/core/process/ports/test_ports.py index 79a982ffe..0e6fe26c4 100644 --- a/tests/lava/magma/core/process/test_ports.py +++ b/tests/lava/magma/core/process/ports/test_ports.py @@ -1,22 +1,27 @@ # Copyright (C) 2021 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause # See: https://spdx.org/licenses/ + import unittest from lava.magma.core.process.process import AbstractProcess +from lava.magma.core.process.variable import Var from lava.magma.core.process.ports.ports import ( InPort, OutPort, RefPort, VarPort, ConcatPort, + TransposePort, ) from lava.magma.core.process.ports.exceptions import ( ReshapeError, DuplicateConnectionError, ConcatShapeError, + ConcatIndexError, + TransposeShapeError, + TransposeIndexError, VarNotSharableError, ) -from lava.magma.core.process.variable import Var class TestPortInitialization(unittest.TestCase): @@ -410,30 +415,51 @@ def test_connect_VarPort_to_InPort_OutPort_RefPort(self): class TestVirtualPorts(unittest.TestCase): """Contains unit tests around virtual ports. Virtual ports are derived - ports that are not directly created by developer as part of process + ports that are not directly created by the developer as part of process definition but which serve to somehow transform the properties of a - deverloper-defined port.""" + developer-defined port.""" def test_reshape(self): """Checks reshaping of a port.""" # Create some ports op = OutPort((1, 2, 3)) - ip1 = InPort((3, 2, 1)) - ip2 = InPort((3, 2, 10)) + ip = InPort((3, 2, 1)) # Using reshape(..), ports with different shape can be connected as # long as total number of elements does not change - op.reshape((3, 2, 1)).connect(ip1) + op.reshape((3, 2, 1)).connect(ip) # We can still find destination and source connection even with # virtual ports in the chain - self.assertEqual(op.get_dst_ports(), [ip1]) - self.assertEqual(ip1.get_src_ports(), [op]) + self.assertEqual(op.get_dst_ports(), [ip]) + self.assertEqual(ip.get_src_ports(), [op]) + + def test_reshape_with_wrong_number_of_elements_raises_exception(self): + """Checks whether an exception is raised when the number of elements + in the specified shape is different from the number of elements in + the source shape.""" - # However, ports with a different number of elements cannot be connected with self.assertRaises(ReshapeError): - op.reshape((3, 2, 10)).connect(ip2) + OutPort((1, 2, 3)).reshape((1, 2, 2)) + + def test_flatten(self): + """Checks flattening of a port.""" + + op = OutPort((1, 2, 3)) + ip = InPort((6,)) + + # Flatten the shape of the port. + fp = op.flatten() + self.assertEqual(fp.shape, (6,)) + + # This enables connecting to an input port with a flattened shape. + fp.connect(ip) + + # We can still find destination and source connection even with + # virtual ports in the chain + self.assertEqual(op.get_dst_ports(), [ip]) + self.assertEqual(ip.get_src_ports(), [op]) def test_concat(self): """Checks concatenation of ports.""" @@ -469,7 +495,7 @@ def test_concat(self): # (2, 3, 1) + (2, 3, 1) concatenated along axis 1 results in (2, 6, 1) self.assertEqual(ip2.in_connections[0].shape, (2, 6, 1)) - def test_concat_with_incompatible_ports(self): + def test_concat_with_incompatible_shapes_raises_exception(self): """Checks that incompatible ports cannot be concatenated.""" # Create ports with incompatible shapes @@ -481,11 +507,69 @@ def test_concat_with_incompatible_ports(self): with self.assertRaises(ConcatShapeError): op1.concat_with(op2, axis=0) - # Create another port with incompatible type + def test_concat_with_incompatible_type_raises_exception(self): + """Checks that incompatible port types raise an exception.""" + + op = OutPort((2, 3, 1)) ip = InPort((2, 3, 1)) # This will fail because concatenated ports must be of same type with self.assertRaises(AssertionError): - op1.concat_with(ip, axis=0) + op.concat_with(ip, axis=0) + + def test_concat_with_axis_out_of_bounds_raises_exception(self): + """Checks whether an exception is raised when the specified axis is + out of bounds.""" + + op1 = OutPort((2, 3, 1)) + op2 = OutPort((2, 3, 1)) + with self.assertRaises(ConcatIndexError): + op1.concat_with(op2, axis=3) + + def test_transpose(self): + """Checks transposing of ports.""" + + op = OutPort((1, 2, 3)) + ip = InPort((2, 1, 3)) + + tp = op.transpose(axes=(1, 0, 2)) + # The return value is a virtual TransposePort ... + self.assertIsInstance(tp, TransposePort) + # ... which needs to have the same dimensions ... + self.assertEqual(tp.shape, (2, 1, 3)) + # ... as the port we want to connect to. + tp.connect(ip) + # Finally, the virtual TransposePort is the input connection of the ip + self.assertEqual(tp, ip.in_connections[0]) + + # Again, we can still find destination and source ports through a + # chain of ports containing virtual ports + self.assertEqual(op.get_dst_ports(), [ip]) + self.assertEqual(ip.get_src_ports(), [op]) + + def test_transpose_without_specified_axes(self): + """Checks whether transpose reverses the shape-elements when no + 'axes' argument is given.""" + + op = OutPort((1, 2, 3)) + tp = op.transpose() + self.assertEqual(tp.shape, (3, 2, 1)) + + def test_transpose_incompatible_axes_length_raises_exception(self): + """Checks whether an exception is raised when the number of elements + in the specified 'axes' argument differs from the number of elements + of the parent port.""" + + op = OutPort((1, 2, 3)) + with self.assertRaises(TransposeShapeError): + op.transpose(axes=(0, 0, 1, 2)) + + def test_transpose_incompatible_axes_indices_raises_exception(self): + """Checks whether an exception is raised when the indices specified + in the 'axes' argument are out of bounds for the parent port.""" + + op = OutPort((1, 2, 3)) + with self.assertRaises(TransposeIndexError): + op.transpose(axes=(0, 1, 3)) if __name__ == "__main__": diff --git a/tests/lava/magma/core/process/ports/test_virtual_ports_in_process.py b/tests/lava/magma/core/process/ports/test_virtual_ports_in_process.py new file mode 100644 index 000000000..474146fd5 --- /dev/null +++ b/tests/lava/magma/core/process/ports/test_virtual_ports_in_process.py @@ -0,0 +1,320 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import typing as ty +import unittest +import numpy as np +import functools as ft + +from lava.magma.core.decorator import requires, tag, implements +from lava.magma.core.model.py.model import PyLoihiProcessModel +from lava.magma.core.model.sub.model import AbstractSubProcessModel +from lava.magma.core.model.py.type import LavaPyType +from lava.magma.core.process.variable import Var +from lava.magma.core.process.process import AbstractProcess +from lava.magma.core.resources import CPU +from lava.magma.core.run_conditions import RunSteps +from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol +from lava.magma.core.run_configs import Loihi1SimCfg +from lava.magma.core.model.py.ports import ( + PyInPort, + PyOutPort +) +from lava.magma.core.process.ports.ports import ( + AbstractPort, + AbstractVirtualPort, + InPort, + OutPort +) + + +np.random.seed(7739) + + +class MockVirtualPort(AbstractVirtualPort, AbstractPort): + """A mock-up of a virtual port that reshapes the input.""" + def __init__(self, new_shape: ty.Tuple): + AbstractPort.__init__(self, new_shape) + + def get_transform_func(self) -> ft.partial: + return ft.partial(np.reshape, newshape=self.shape) + + +class TestVirtualPortNetworkTopologies(unittest.TestCase): + """Tests different network topologies that include virtual ports using a + dummy virtual port as a stand-in for all types of virtual ports.""" + + def setUp(self) -> None: + self.num_steps = 1 + self.shape = (4, 3, 2) + self.new_shape = (12, 2) + self.input_data = np.random.randint(256, size=self.shape) + + def test_virtual_ports_between_hierarchical_processes(self) -> None: + """Tests a virtual port between an OutPort of a hierarchical Process + and an InPort of another hierarchical Process.""" + + source = HOutPortProcess(data=self.input_data) + sink = HInPortProcess(shape=self.new_shape) + + virtual_port = MockVirtualPort(new_shape=self.new_shape) + + source.out_port._connect_forward( + [virtual_port], AbstractPort, assert_same_shape=False + ) + virtual_port.connect(sink.in_port) + + sink.run(condition=RunSteps(num_steps=self.num_steps), + run_cfg=Loihi1SimCfg(select_tag='floating_pt')) + output = sink.data.get() + sink.stop() + + expected = self.input_data.reshape(self.new_shape) + self.assertTrue( + np.all(output == expected), + f'Input and output do not match.\n' + f'{output[output!=expected]=}\n' + f'{expected[output!=expected] =}\n' + ) + + def test_chaining_multiple_virtual_ports(self) -> None: + """Tests whether two virtual ReshapePorts can be chained through the + flatten() method.""" + + source = OutPortProcess(data=self.input_data) + shape_final = (int(np.prod(self.shape)),) + sink = InPortProcess(shape=shape_final) + + virtual_port1 = MockVirtualPort(new_shape=self.new_shape) + virtual_port2 = MockVirtualPort(new_shape=shape_final) + + source.out_port._connect_forward( + [virtual_port1], AbstractPort, assert_same_shape=False + ) + virtual_port1._connect_forward( + [virtual_port2], AbstractPort, assert_same_shape=False + ) + virtual_port2.connect(sink.in_port) + + sink.run(condition=RunSteps(num_steps=self.num_steps), + run_cfg=Loihi1SimCfg(select_tag='floating_pt')) + output = sink.data.get() + sink.stop() + + expected = self.input_data.ravel() + self.assertTrue( + np.all(output == expected), + f'Input and output do not match.\n' + f'{output[output!=expected]=}\n' + f'{expected[output!=expected] =}\n' + ) + + def test_joining_virtual_ports_throws_exception(self) -> None: + """Tests whether joining two virtual ports throws an exception.""" + + source1 = OutPortProcess(data=self.input_data) + source2 = OutPortProcess(data=self.input_data) + sink = InPortProcess(shape=self.new_shape) + + virtual_port1 = MockVirtualPort(new_shape=self.new_shape) + virtual_port2 = MockVirtualPort(new_shape=self.new_shape) + + source1.out_port._connect_forward( + [virtual_port1], AbstractPort, assert_same_shape=False + ) + source2.out_port._connect_forward( + [virtual_port2], AbstractPort, assert_same_shape=False + ) + + virtual_port1.connect(sink.in_port) + virtual_port2.connect(sink.in_port) + + with self.assertRaises(NotImplementedError): + sink.run(condition=RunSteps(num_steps=self.num_steps), + run_cfg=Loihi1SimCfg(select_tag='floating_pt')) + + +class TestTransposePort(unittest.TestCase): + """Tests virtual TransposePorts on Processes that are executed.""" + + def setUp(self) -> None: + self.num_steps = 1 + self.axes = (2, 0, 1) + self.axes_reverse = list(self.axes) + for idx, ax in enumerate(self.axes): + self.axes_reverse[ax] = idx + self.axes_reverse = tuple(self.axes_reverse) + self.shape = (4, 3, 2) + self.shape_transposed = tuple(self.shape[i] for i in self.axes) + self.shape_transposed_reverse = \ + tuple(self.shape[i] for i in self.axes_reverse) + self.input_data = np.random.randint(256, size=self.shape) + + def test_transpose_outport_to_inport(self) -> None: + """Tests a virtual TransposePort between an OutPort and an InPort.""" + + source = OutPortProcess(data=self.input_data) + sink = InPortProcess(shape=self.shape_transposed) + + source.out_port.transpose(axes=self.axes).connect(sink.in_port) + + sink.run(condition=RunSteps(num_steps=self.num_steps), + run_cfg=Loihi1SimCfg(select_tag='floating_pt')) + output = sink.data.get() + sink.stop() + + expected = self.input_data.transpose(self.axes) + self.assertTrue( + np.all(output == expected), + f'Input and output do not match.\n' + f'{output[output!=expected]=}\n' + f'{expected[output!=expected] =}\n' + ) + + +class TestReshapePort(unittest.TestCase): + """Tests virtual ReshapePorts on Processes that are executed.""" + + def setUp(self) -> None: + self.num_steps = 1 + self.shape = (4, 3, 2) + self.shape_reshaped = (12, 2) + self.input_data = np.random.randint(256, size=self.shape) + + def test_reshape_outport_to_inport(self) -> None: + """Tests a virtual ReshapePort between an OutPort and an InPort.""" + + source = OutPortProcess(data=self.input_data) + sink = InPortProcess(shape=self.shape_reshaped) + + source.out_port.reshape(new_shape=self.shape_reshaped).connect( + sink.in_port) + + sink.run(condition=RunSteps(num_steps=self.num_steps), + run_cfg=Loihi1SimCfg(select_tag='floating_pt')) + output = sink.data.get() + sink.stop() + + expected = self.input_data.reshape(self.shape_reshaped) + self.assertTrue( + np.all(output == expected), + f'Input and output do not match.\n' + f'{output[output!=expected]=}\n' + f'{expected[output!=expected] =}\n' + ) + + +class TestFlattenPort(unittest.TestCase): + """Tests virtual ReshapePorts, created by the flatten() method, + on Processes that are executed.""" + + def setUp(self) -> None: + self.num_steps = 1 + self.shape = (4, 3, 2) + self.shape_reshaped = (24,) + self.input_data = np.random.randint(256, size=self.shape) + + def test_flatten_outport_to_inport(self) -> None: + """Tests a virtual ReshapePort with flatten() between an OutPort and an + InPort.""" + + source = OutPortProcess(data=self.input_data) + sink = InPortProcess(shape=self.shape_reshaped) + + source.out_port.flatten().connect(sink.in_port) + + sink.run(condition=RunSteps(num_steps=self.num_steps), + run_cfg=Loihi1SimCfg(select_tag='floating_pt')) + output = sink.data.get() + sink.stop() + + expected = self.input_data.ravel() + self.assertTrue( + np.all(output == expected), + f'Input and output do not match.\n' + f'{output[output!=expected]=}\n' + f'{expected[output!=expected] =}\n' + ) + + +# A minimal Process with an OutPort +class OutPortProcess(AbstractProcess): + def __init__(self, data: np.ndarray) -> None: + super().__init__(data=data) + self.data = Var(shape=data.shape, init=data) + self.out_port = OutPort(shape=data.shape) + + +# A minimal Process with an InPort +class InPortProcess(AbstractProcess): + def __init__(self, shape: ty.Tuple[int, ...]) -> None: + super().__init__(shape=shape) + self.data = Var(shape=shape, init=np.zeros(shape)) + self.in_port = InPort(shape=shape) + + +# A minimal hierarchical Process with an OutPort +class HOutPortProcess(AbstractProcess): + def __init__(self, data: np.ndarray) -> None: + super().__init__(data=data) + self.data = Var(shape=data.shape, init=data) + self.out_port = OutPort(shape=data.shape) + self.proc_params['data'] = data + + +# A minimal hierarchical Process with an InPort and a Var +class HInPortProcess(AbstractProcess): + def __init__(self, shape: ty.Tuple[int, ...]) -> None: + super().__init__(shape=shape) + self.data = Var(shape=shape, init=np.zeros(shape)) + self.in_port = InPort(shape=shape) + self.proc_params['shape'] = shape + + +# A minimal PyProcModel implementing OutPortProcess +@implements(proc=OutPortProcess, protocol=LoihiProtocol) +@requires(CPU) +@tag('floating_pt') +class PyOutPortProcessModelFloat(PyLoihiProcessModel): + out_port: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32) + data: np.ndarray = LavaPyType(np.ndarray, np.int32) + + def run_spk(self): + self.out_port.send(self.data) + print("Sent output data of OutPortProcess: ", str(self.data)) + + +# A minimal PyProcModel implementing InPortProcess +@implements(proc=InPortProcess, protocol=LoihiProtocol) +@requires(CPU) +@tag('floating_pt') +class PyInPortProcessModelFloat(PyLoihiProcessModel): + in_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int32) + data: np.ndarray = LavaPyType(np.ndarray, np.int32) + + def run_spk(self): + self.data[:] = self.in_port.recv() + print("Received input data for InPortProcess: ", str(self.data)) + + +# A minimal hierarchical ProcModel with a nested OutPortProcess +@implements(proc=HOutPortProcess) +class SubHOutPortProcModel(AbstractSubProcessModel): + def __init__(self, proc): + self.out_proc = OutPortProcess(data=proc.proc_params['data']) + self.out_proc.out_port.connect(proc.out_port) + + +# A minimal hierarchical ProcModel with a nested InPortProcess and an aliased +# Var +@implements(proc=HInPortProcess) +class SubHInPortProcModel(AbstractSubProcessModel): + def __init__(self, proc): + self.in_proc = InPortProcess(shape=proc.proc_params['shape']) + proc.in_port.connect(self.in_proc.in_port) + proc.data.alias(self.in_proc.data) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/lava/magma/core/process/test_virtual_ports_in_process.py b/tests/lava/magma/core/process/test_virtual_ports_in_process.py deleted file mode 100644 index 84c086bb2..000000000 --- a/tests/lava/magma/core/process/test_virtual_ports_in_process.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (C) 2021 Intel Corporation -# SPDX-License-Identifier: BSD-3-Clause -import unittest -import numpy as np - -from lava.magma.core.decorator import implements, requires -from lava.magma.core.model.py.model import AbstractPyProcessModel -from lava.magma.core.model.py.ports import PyInPort, PyOutPort -from lava.magma.core.model.py.type import LavaPyType -from lava.magma.core.process.ports.ports import ( - InPort, - OutPort, -) -from lava.magma.core.process.process import AbstractProcess -from lava.magma.core.resources import CPU -from lava.magma.core.run_conditions import RunSteps -from lava.magma.core.run_configs import RunConfig - - -class TestVirtualPorts(unittest.TestCase): - """Contains unit tests around virtual ports created as part of a Process.""" - - def setUp(self) -> None: - # minimal process with an OutPort - pass - - @unittest.skip("skip while solved") - def test_multi_inports(self): - sender = P1() - recv1 = P2() - recv2 = P2() - recv3 = P2() - - # An OutPort can connect to multiple InPorts - # Either at once... - sender.out.connect([recv1.inp, recv2.inp, recv3.inp]) - - sender = P1() - recv1 = P2() - recv2 = P2() - recv3 = P2() - - # ... or consecutively - sender.out.connect(recv1.inp) - sender.out.connect(recv2.inp) - sender.out.connect(recv3.inp) - sender.run(RunSteps(num_steps=2), MyRunCfg()) - - @unittest.skip("skip while solved") - def test_reshape(self): - """Checks reshaping of a port.""" - sender = P1(shape=(1, 6)) - recv = P2(shape=(2, 3)) - - # Using reshape(..), ports with different shape can be connected as - # long as total number of elements does not change - sender.out.reshape((2, 3)).connect(recv.inp) - sender.run(RunSteps(num_steps=2), MyRunCfg()) - - @unittest.skip("skip while solved") - def test_concat(self): - """Checks concatenation of ports.""" - sender1 = P1(shape=(1, 2)) - sender2 = P1(shape=(1, 2)) - sender3 = P1(shape=(1, 2)) - recv = P2(shape=(3, 2)) - - # concat_with(..) concatenates calling port (sender1.out) with - # other ports (sender2.out, sender3.out) along given axis - cp = sender1.out.concat_with([sender2.out, sender3.out], axis=0) - - # The return value is a virtual ConcatPort which can be connected - # to the input port - cp.connect(recv.inp) - sender1.run(RunSteps(num_steps=2), MyRunCfg()) - - -# minimal process with an OutPort -class P1(AbstractProcess): - def __init__(self, **kwargs): - super().__init__(**kwargs) - shape = kwargs.get('shape', (3,)) - self.out = OutPort(shape=shape) - - -# minimal process with an InPort -class P2(AbstractProcess): - def __init__(self, **kwargs): - super().__init__(**kwargs) - shape = kwargs.get('shape', (3,)) - self.inp = InPort(shape=shape) - - -# A minimal PyProcModel implementing P1 -@implements(proc=P1) -@requires(CPU) -class PyProcModelA(AbstractPyProcessModel): - out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int) - - def run(self): - data = np.asarray([1, 2, 3]) - self.out.send(data) - print("Sent output data of P1: ", str(data)) - - -# A minimal PyProcModel implementing P2 -@implements(proc=P2) -@requires(CPU) -class PyProcModelB(AbstractPyProcessModel): - inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int) - - def run(self): - in_data = self.inp.recv() - print("Received input data for P2: ", str(in_data)) - - -class MyRunCfg(RunConfig): - def select(self, proc, proc_models): - return proc_models[0] - - -if __name__ == '__main__': - unittest.main()