diff --git a/src/lava/magma/compiler/builder.py b/src/lava/magma/compiler/builder.py index 239b95d32..db74adda6 100644 --- a/src/lava/magma/compiler/builder.py +++ b/src/lava/magma/compiler/builder.py @@ -5,10 +5,10 @@ import typing as ty from lava.magma.core.sync.protocol import AbstractSyncProtocol -from lava.magma.runtime.runtime_service import ( - PyRuntimeService, - AbstractRuntimeService, -) +from lava.magma.runtime.message_infrastructure.message_infrastructure_interface\ + import MessageInfrastructureInterface +from lava.magma.runtime.runtime_service import PyRuntimeService, \ + AbstractRuntimeService if ty.TYPE_CHECKING: from lava.magma.core.process.process import AbstractProcess @@ -16,26 +16,18 @@ from lava.magma.runtime.runtime import Runtime from abc import ABC, abstractmethod -from multiprocessing.managers import SharedMemoryManager import numpy as np from dataclasses import dataclass -from lava.magma.compiler.channels.pypychannel import ( - PyPyChannel, - CspSendPort, - CspRecvPort, -) +from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort from lava.magma.core.model.py.model import AbstractPyProcessModel from lava.magma.core.model.py.type import LavaPyType from lava.magma.compiler.utils import VarInitializer, PortInitializer -from lava.magma.core.model.py.ports import ( - AbstractPyPort, - PyInPort, - PyOutPort, - PyRefPort, -) -from lava.magma.compiler.channels.interfaces import AbstractCspPort, Channel +from lava.magma.core.model.py.ports import AbstractPyPort, \ + PyInPort, PyOutPort, PyRefPort +from lava.magma.compiler.channels.interfaces import AbstractCspPort, Channel, \ + ChannelType class AbstractProcessBuilder(ABC): @@ -497,19 +489,19 @@ class ChannelBuilderMp(AbstractChannelBuilder): """A ChannelBuilder assuming Python multi-processing is used as messaging and multi processing backbone. """ - - channel_type: ty.Type[Channel] + channel_type: ChannelType src_process: "AbstractProcess" dst_process: "AbstractProcess" src_port_initializer: PortInitializer dst_port_initializer: PortInitializer - def build(self, messaging_infrastructure: SharedMemoryManager) -> Channel: + def build(self, messaging_infrastructure: MessageInfrastructureInterface) \ + -> Channel: """Given the message passing framework builds a channel Parameters ---------- - messaging_infrastructure : SharedMemoryManager + messaging_infrastructure : MessageInfrastructureInterface Returns ------- @@ -521,17 +513,16 @@ def build(self, messaging_infrastructure: SharedMemoryManager) -> Channel: Exception Can't build channel of type specified """ - if self.channel_type == PyPyChannel: - return PyPyChannel( - messaging_infrastructure, - self.src_port_initializer.name, - self.dst_port_initializer.name, - self.src_port_initializer.shape, - self.src_port_initializer.d_type, - self.src_port_initializer.size, - ) - else: - raise Exception(f"Can't build channel of type {self.channel_type}") + channel_class = messaging_infrastructure.channel_class( + channel_type=self.channel_type) + return channel_class( + messaging_infrastructure, + self.src_port_initializer.name, + self.dst_port_initializer.name, + self.src_port_initializer.shape, + self.src_port_initializer.d_type, + self.src_port_initializer.size, + ) @dataclass @@ -539,20 +530,18 @@ class ServiceChannelBuilderMp(AbstractChannelBuilder): """A RuntimeServiceChannelBuilder assuming Python multi-processing is used as messaging and multi processing backbone. """ - - channel_type: ty.Type[Channel] - src_process: ty.Union[AbstractRuntimeServiceBuilder, - "AbstractProcessModel"] - dst_process: ty.Union[AbstractRuntimeServiceBuilder, - "AbstractProcessModel"] + channel_type: ChannelType + src_process: ty.Union[AbstractRuntimeServiceBuilder, "AbstractProcessModel"] + dst_process: ty.Union[AbstractRuntimeServiceBuilder, "AbstractProcessModel"] port_initializer: PortInitializer - def build(self, messaging_infrastructure: SharedMemoryManager) -> Channel: + def build(self, messaging_infrastructure: MessageInfrastructureInterface) \ + -> Channel: """Given the message passing framework builds a channel Parameters ---------- - messaging_infrastructure : SharedMemoryManager + messaging_infrastructure : MessageInfrastructureInterface Returns ------- @@ -564,18 +553,20 @@ def build(self, messaging_infrastructure: SharedMemoryManager) -> Channel: Exception Can't build channel of type specified """ - if self.channel_type == PyPyChannel: - channel_name: str = self.port_initializer.name - return PyPyChannel( - messaging_infrastructure, - channel_name + "_src", - channel_name + "_dst", - self.port_initializer.shape, - self.port_initializer.d_type, - self.port_initializer.size, - ) - else: - raise Exception(f"Can't build channel of type {self.channel_type}") + channel_class = messaging_infrastructure.channel_class( + channel_type=self.channel_type) + + channel_name: str = ( + self.port_initializer.name + ) + return channel_class( + messaging_infrastructure, + channel_name + "_src", + channel_name + "_dst", + self.port_initializer.shape, + self.port_initializer.d_type, + self.port_initializer.size, + ) @dataclass @@ -583,18 +574,18 @@ class RuntimeChannelBuilderMp(AbstractChannelBuilder): """A RuntimeChannelBuilder assuming Python multi-processing is used as messaging and multi processing backbone. """ - - channel_type: ty.Type[Channel] + channel_type: ChannelType src_process: ty.Union[AbstractRuntimeServiceBuilder, ty.Type["Runtime"]] dst_process: ty.Union[AbstractRuntimeServiceBuilder, ty.Type["Runtime"]] port_initializer: PortInitializer - def build(self, messaging_infrastructure: SharedMemoryManager) -> Channel: + def build(self, messaging_infrastructure: MessageInfrastructureInterface) \ + -> Channel: """Given the message passing framework builds a channel Parameters ---------- - messaging_infrastructure : SharedMemoryManager + messaging_infrastructure : MessageInfrastructureInterface Returns ------- @@ -606,15 +597,17 @@ def build(self, messaging_infrastructure: SharedMemoryManager) -> Channel: Exception Can't build channel of type specified """ - if self.channel_type == PyPyChannel: - channel_name: str = self.port_initializer.name - return PyPyChannel( - messaging_infrastructure, - channel_name + "_src", - channel_name + "_dst", - self.port_initializer.shape, - self.port_initializer.d_type, - self.port_initializer.size, - ) - else: - raise Exception(f"Can't build channel of type {self.channel_type}") + channel_class = messaging_infrastructure.channel_class( + channel_type=self.channel_type) + + channel_name: str = ( + self.port_initializer.name + ) + return channel_class( + messaging_infrastructure, + channel_name + "_src", + channel_name + "_dst", + self.port_initializer.shape, + self.port_initializer.d_type, + self.port_initializer.size, + ) diff --git a/src/lava/magma/compiler/channels/interfaces.py b/src/lava/magma/compiler/channels/interfaces.py index 5b3d32446..0275e37d1 100644 --- a/src/lava/magma/compiler/channels/interfaces.py +++ b/src/lava/magma/compiler/channels/interfaces.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: BSD-3-Clause # See: https://spdx.org/licenses/ import typing as ty +from enum import IntEnum + import numpy as np from abc import ABC, abstractmethod @@ -61,3 +63,10 @@ def src_port(self) -> AbstractCspSendPort: @abstractmethod def dst_port(self) -> AbstractCspRecvPort: pass + + +class ChannelType(IntEnum): + """Type of a channel given the two process models""" + PyPy = 0 + CPy = 1 + PyC = 2 diff --git a/src/lava/magma/compiler/channels/pypychannel.py b/src/lava/magma/compiler/channels/pypychannel.py index 7ead9a6f6..f93287c2b 100644 --- a/src/lava/magma/compiler/channels/pypychannel.py +++ b/src/lava/magma/compiler/channels/pypychannel.py @@ -15,6 +15,9 @@ AbstractCspSendPort, AbstractCspRecvPort, ) +if ty.TYPE_CHECKING: + from lava.magma.runtime.message_infrastructure\ + .message_infrastructure_interface import MessageInfrastructureInterface @dataclass @@ -24,7 +27,6 @@ class Proto: nbytes: int -# ToDo: (AW) Do not create any class attributes outside of __init__ class CspSendPort(AbstractCspSendPort): """ CspSendPort is a low level send port implementation based on CSP @@ -32,6 +34,17 @@ class CspSendPort(AbstractCspSendPort): """ def __init__(self, name, shm, proto, size, req, ack): + """Instantiates CspSendPort object and class attributes + + Parameters + ---------- + name : str + shm : SharedMemory + proto : Proto + size : int + req : Pipe + ack : Pipe + """ self._name = name self._shm = shm self._shape = proto.shape @@ -42,6 +55,8 @@ def __init__(self, name, shm, proto, size, req, ack): self._size = size self._idx = 0 self._done = False + self._array = [] + self._semaphore = None self.thread = None @property @@ -66,7 +81,7 @@ def start(self): np.ndarray( shape=self._shape, dtype=self._dtype, - buffer=self._shm.buf[self._nbytes * i : self._nbytes * (i + 1)], + buffer=self._shm.buf[self._nbytes * i: self._nbytes * (i + 1)], ) for i in range(self._size) ] @@ -152,6 +167,17 @@ class CspRecvPort(AbstractCspRecvPort): """ def __init__(self, name, shm, proto, size, req, ack): + """Instantiates CspRecvPort object and class attributes + + Parameters + ---------- + name : str + shm : SharedMemory + proto : Proto + size : int + req : Pipe + ack : Pipe + """ self._name = name self._shm = shm self._shape = proto.shape @@ -162,6 +188,8 @@ def __init__(self, name, shm, proto, size, req, ack): self._ack = ack self._idx = 0 self._done = False + self._array = [] + self._queue = None self.thread = None @property @@ -186,7 +214,7 @@ def start(self): np.ndarray( shape=self._shape, dtype=self._dtype, - buffer=self._shm.buf[self._nbytes * i : self._nbytes * (i + 1)], + buffer=self._shm.buf[self._nbytes * i: self._nbytes * (i + 1)], ) for i in range(self._size) ] @@ -241,8 +269,26 @@ class PyPyChannel(Channel): """Helper class to create the set of send and recv port and encapsulate them inside a common structure. We call this a PyPyChannel""" - def __init__(self, smm, src_name, dst_name, shape, dtype, size): + def __init__(self, + message_infrastructure: 'MessageInfrastructureInterface', + src_name, + dst_name, + shape, + dtype, + size): + """Instantiates PyPyChannel object and class attributes + + Parameters + ---------- + message_infrastructure: MessageInfrastructureInterface + src_name : str + dst_name : str + shape : ty.Tuple[int, ...] + dtype : ty.Type[np.intc] + size : int + """ nbytes = np.prod(shape) * np.dtype(dtype).itemsize + smm = message_infrastructure.smm shm = smm.SharedMemory(int(nbytes * size)) req = Pipe(duplex=False) ack = Pipe(duplex=False) diff --git a/src/lava/magma/compiler/compiler.py b/src/lava/magma/compiler/compiler.py index 1da5e178c..2948fa0d9 100644 --- a/src/lava/magma/compiler/compiler.py +++ b/src/lava/magma/compiler/compiler.py @@ -21,8 +21,7 @@ AbstractRuntimeServiceBuilder, RuntimeServiceBuilder, \ AbstractChannelBuilder, ServiceChannelBuilderMp from lava.magma.compiler.builder import RuntimeChannelBuilderMp -from lava.magma.compiler.channels.interfaces import Channel -from lava.magma.compiler.channels.pypychannel import PyPyChannel +from lava.magma.compiler.channels.interfaces import ChannelType from lava.magma.compiler.executable import Executable from lava.magma.compiler.node import NodeConfig, Node from lava.magma.compiler.utils import VarInitializer, PortInitializer @@ -483,14 +482,14 @@ def _create_node_cfgs(proc_map: PROC_MAP) -> ty.List[NodeConfig]: @staticmethod def _get_channel_type(src: ty.Type[AbstractProcessModel], dst: ty.Type[AbstractProcessModel]) \ - -> ty.Type[Channel]: + -> ChannelType: """Returns appropriate ChannelType for a given (source, destination) pair of ProcessModels.""" if issubclass(src, AbstractPyProcessModel) and issubclass( dst, AbstractPyProcessModel ): - return PyPyChannel + return ChannelType.PyPy else: raise NotImplementedError( f"No support for (source, destination) pairs of type " @@ -596,7 +595,7 @@ def _create_sync_channel_builders( sync_channel_builders: ty.List[AbstractChannelBuilder] = [] for sync_domain in rsb: runtime_to_service_cmd = \ - RuntimeChannelBuilderMp(PyPyChannel, + RuntimeChannelBuilderMp(ChannelType.PyPy, Runtime, rsb[sync_domain], self._create_mgmt_port_initializer( @@ -605,7 +604,7 @@ def _create_sync_channel_builders( sync_channel_builders.append(runtime_to_service_cmd) service_to_runtime_ack = \ - RuntimeChannelBuilderMp(PyPyChannel, + RuntimeChannelBuilderMp(ChannelType.PyPy, rsb[sync_domain], Runtime, self._create_mgmt_port_initializer( @@ -614,7 +613,7 @@ def _create_sync_channel_builders( sync_channel_builders.append(service_to_runtime_ack) runtime_to_service_req = \ - RuntimeChannelBuilderMp(PyPyChannel, + RuntimeChannelBuilderMp(ChannelType.PyPy, Runtime, rsb[sync_domain], self._create_mgmt_port_initializer( @@ -623,7 +622,7 @@ def _create_sync_channel_builders( sync_channel_builders.append(runtime_to_service_req) service_to_runtime_data = \ - RuntimeChannelBuilderMp(PyPyChannel, + RuntimeChannelBuilderMp(ChannelType.PyPy, rsb[sync_domain], Runtime, self._create_mgmt_port_initializer( @@ -632,7 +631,7 @@ def _create_sync_channel_builders( sync_channel_builders.append(service_to_runtime_data) runtime_to_service_data = \ - RuntimeChannelBuilderMp(PyPyChannel, + RuntimeChannelBuilderMp(ChannelType.PyPy, Runtime, rsb[sync_domain], self._create_mgmt_port_initializer( @@ -642,7 +641,7 @@ def _create_sync_channel_builders( for process in sync_domain.processes: service_to_process_cmd = \ - ServiceChannelBuilderMp(PyPyChannel, + ServiceChannelBuilderMp(ChannelType.PyPy, rsb[sync_domain], process, self._create_mgmt_port_initializer( @@ -651,7 +650,7 @@ def _create_sync_channel_builders( sync_channel_builders.append(service_to_process_cmd) process_to_service_ack = \ - ServiceChannelBuilderMp(PyPyChannel, + ServiceChannelBuilderMp(ChannelType.PyPy, process, rsb[sync_domain], self._create_mgmt_port_initializer( @@ -660,7 +659,7 @@ def _create_sync_channel_builders( sync_channel_builders.append(process_to_service_ack) service_to_process_req = \ - ServiceChannelBuilderMp(PyPyChannel, + ServiceChannelBuilderMp(ChannelType.PyPy, rsb[sync_domain], process, self._create_mgmt_port_initializer( @@ -669,7 +668,7 @@ def _create_sync_channel_builders( sync_channel_builders.append(service_to_process_req) process_to_service_data = \ - ServiceChannelBuilderMp(PyPyChannel, + ServiceChannelBuilderMp(ChannelType.PyPy, process, rsb[sync_domain], self._create_mgmt_port_initializer( @@ -678,7 +677,7 @@ def _create_sync_channel_builders( sync_channel_builders.append(process_to_service_data) service_to_process_data = \ - ServiceChannelBuilderMp(PyPyChannel, + ServiceChannelBuilderMp(ChannelType.PyPy, rsb[sync_domain], process, self._create_mgmt_port_initializer( diff --git a/src/lava/magma/core/process/message_interface_enum.py b/src/lava/magma/core/process/message_interface_enum.py new file mode 100644 index 000000000..2e5c5188a --- /dev/null +++ b/src/lava/magma/core/process/message_interface_enum.py @@ -0,0 +1,8 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ +from enum import IntEnum + + +class ActorType(IntEnum): + MultiProcessing = 0 diff --git a/src/lava/magma/core/process/process.py b/src/lava/magma/core/process/process.py index 27f242221..6239072f3 100644 --- a/src/lava/magma/core/process/process.py +++ b/src/lava/magma/core/process/process.py @@ -4,6 +4,7 @@ import typing as ty from _collections import OrderedDict +from lava.magma.core.process.message_interface_enum import ActorType from lava.magma.core.run_conditions import AbstractRunCondition from lava.magma.core.run_configs import RunConfig from lava.magma.core.process.ports.ports import \ @@ -377,7 +378,9 @@ def run(self, condition: AbstractRunCondition, run_cfg: RunConfig): """ if not self._runtime: executable = self.compile(run_cfg) - self._runtime = Runtime(condition, executable) + self._runtime = Runtime(condition, + executable, + ActorType.MultiProcessing) self._runtime.initialize() self._runtime.start(condition) diff --git a/src/lava/magma/runtime/message_infrastructure/__init__.py b/src/lava/magma/runtime/message_infrastructure/__init__.py new file mode 100644 index 000000000..fc778f31a --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/__init__.py @@ -0,0 +1,5 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +__import__("pkg_resources").declare_namespace(__name__) diff --git a/src/lava/magma/runtime/message_infrastructure/factory.py b/src/lava/magma/runtime/message_infrastructure/factory.py new file mode 100644 index 000000000..7db90bd91 --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/factory.py @@ -0,0 +1,17 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ +from lava.magma.core.process.message_interface_enum import ActorType +from lava.magma.runtime.message_infrastructure.multiprocessing import \ + MultiProcessing + + +class MessageInfrastructureFactory: + """Creates the message infrastructure instance based on type""" + @staticmethod + def create(factory_type: ActorType): + """type of actor framework being chosen""" + if factory_type == ActorType.MultiProcessing: + return MultiProcessing() + else: + raise Exception("Unsupported factory_type") diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure_interface.py b/src/lava/magma/runtime/message_infrastructure/message_infrastructure_interface.py new file mode 100644 index 000000000..7b6b2cd44 --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure_interface.py @@ -0,0 +1,44 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ +import typing as ty +if ty.TYPE_CHECKING: + from lava.magma.core.process.process import AbstractProcess + from lava.magma.compiler.builder import AbstractRuntimeServiceBuilder, \ + PyProcessBuilder + +from abc import ABC, abstractmethod + +from lava.magma.compiler.channels.interfaces import ChannelType, Channel +from lava.magma.core.sync.domain import SyncDomain + + +class MessageInfrastructureInterface(ABC): + """Interface to provide the ability to create actors which can + communicate via message passing""" + @abstractmethod + def start(self): + """Starts the messaging infrastructure""" + pass + + @abstractmethod + def stop(self): + """Stops the messaging infrastructure""" + pass + + @abstractmethod + def build_actor(self, target_fn: ty.Callable, builder: ty.Union[ + ty.Dict['AbstractProcess', 'PyProcessBuilder'], ty.Dict[ + SyncDomain, 'AbstractRuntimeServiceBuilder']]): + """Given a target_fn starts a system process""" + pass + + @property + @abstractmethod + def actors(self) -> ty.List[ty.Any]: + """Returns a list of actors""" + pass + + @abstractmethod + def channel_class(self, channel_type: ChannelType) -> ty.Type[Channel]: + pass diff --git a/src/lava/magma/runtime/message_infrastructure/multiprocessing.py b/src/lava/magma/runtime/message_infrastructure/multiprocessing.py new file mode 100644 index 000000000..b2d4f7396 --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/multiprocessing.py @@ -0,0 +1,64 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ +import typing as ty +if ty.TYPE_CHECKING: + from lava.magma.core.process.process import AbstractProcess + from lava.magma.compiler.builder import PyProcessBuilder, \ + AbstractRuntimeServiceBuilder + +from multiprocessing import Process as SystemProcess +from multiprocessing.managers import SharedMemoryManager + +from lava.magma.compiler.channels.interfaces import ChannelType, Channel +from lava.magma.compiler.channels.pypychannel import PyPyChannel + +from lava.magma.core.sync.domain import SyncDomain +from lava.magma.runtime.message_infrastructure.message_infrastructure_interface\ + import MessageInfrastructureInterface + + +class MultiProcessing(MessageInfrastructureInterface): + """Implements message passing using shared memory and multiprocessing""" + def __init__(self): + self._smm: ty.Optional[SharedMemoryManager] = None + self._actors: ty.List[SystemProcess] = [] + + @property + def actors(self): + """Returns a list of actors""" + return self._actors + + @property + def smm(self): + """Returns the underlying shared memory manager""" + return self._smm + + def start(self): + """Starts the shared memory manager""" + self._smm = SharedMemoryManager() + self._smm.start() + + def build_actor(self, target_fn: ty.Callable, builder: ty.Union[ + ty.Dict['AbstractProcess', 'PyProcessBuilder'], ty.Dict[ + SyncDomain, 'AbstractRuntimeServiceBuilder']]) -> ty.Any: + """Given a target_fn starts a system (os) process""" + system_process = SystemProcess(target=target_fn, + args=(), + kwargs={"builder": builder}) + system_process.start() + self._actors.append(system_process) + return system_process + + def stop(self): + """Stops the shared memory manager""" + for actor in self._actors: + actor.join() + + self._smm.shutdown() + + def channel_class(self, channel_type: ChannelType) -> ty.Type[Channel]: + if channel_type == ChannelType.PyPy: + return PyPyChannel + else: + raise Exception(f"Unsupported channel type {channel_type}") diff --git a/src/lava/magma/runtime/runtime.py b/src/lava/magma/runtime/runtime.py index c899e4dc2..d259e1633 100644 --- a/src/lava/magma/runtime/runtime.py +++ b/src/lava/magma/runtime/runtime.py @@ -9,14 +9,17 @@ from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort from lava.magma.compiler.exec_var import AbstractExecVar +from lava.magma.core.process.message_interface_enum import ActorType +from lava.magma.runtime.message_infrastructure.message_infrastructure_interface\ + import MessageInfrastructureInterface +from lava.magma.runtime.message_infrastructure.factory import \ + MessageInfrastructureFactory from lava.magma.runtime.mgmt_token_enums import MGMT_COMMAND, MGMT_RESPONSE, \ enum_to_np, REQ_TYPE from lava.magma.runtime.runtime_service import AsyncPyRuntimeService if ty.TYPE_CHECKING: from lava.magma.core.process.process import AbstractProcess -from multiprocessing import Process as UnixProcess -from multiprocessing.managers import SharedMemoryManager from lava.magma.compiler.builder import AbstractProcessBuilder, \ RuntimeChannelBuilderMp, ServiceChannelBuilderMp, \ RuntimeServiceBuilder @@ -40,13 +43,17 @@ class Runtime: the APIs to start, pause, stop and wait on an execution. Execution could be blocking and non-blocking as specified by the run run_condition.""" - def __init__(self, run_cond: AbstractRunCondition, exe: Executable): + def __init__(self, + run_cond: AbstractRunCondition, + exe: Executable, + message_infrastructure_type: ActorType): self._run_cond: AbstractRunCondition = run_cond self._executable: Executable = exe - # Abstract the SharedMemoryManager to Generic Messaging Infrastructure - self._messaging_infrastructure: ty.Optional[SharedMemoryManager] = None - self._actors: ty.List[UnixProcess] = [] + self._messaging_infrastructure_type: ActorType = \ + message_infrastructure_type + self._messaging_infrastructure: \ + ty.Optional[MessageInfrastructureInterface] = None self.current_ts = 0 self._is_initialized = False self._is_running = False @@ -99,16 +106,9 @@ def node_cfg(self) -> NodeConfig: """Returns the selected NodeCfg.""" return self._executable.node_configs[0] - def _build_mp_actor(self, target_fn, builder): - """Given a target_fn starts a unix process""" - unix_process = UnixProcess(target=target_fn, - args=(), - kwargs={"builder": builder}) - unix_process.start() - self._actors.append(unix_process) - def _build_message_infrastructure(self): - self._messaging_infrastructure = SharedMemoryManager() + self._messaging_infrastructure = MessageInfrastructureFactory.create( + self._messaging_infrastructure_type) self._messaging_infrastructure.start() def _get_process_builder_for_process(self, process): @@ -178,7 +178,8 @@ def _build_sync_channels(self): # ToDo: (AW) Why not pass the builder as an argument to the mp.Process # constructor which will then be passed to the target function? def _build_processes(self): - process_builders_collection = [ + process_builders_collection: ty.List[ + ty.Dict[AbstractProcess, AbstractProcessBuilder]] = [ self._executable.py_builders, self._executable.c_builders, self._executable.nc_builders, @@ -189,15 +190,17 @@ def _build_processes(self): for proc, proc_builder in process_builders.items(): # Assign current Runtime to process proc._runtime = self - self._build_mp_actor(target_fn=target_fn, - builder=proc_builder) + self._messaging_infrastructure.build_actor( + target_fn=target_fn, + builder=proc_builder) def _build_runtime_services(self): runtime_service_builders = self._executable.rs_builders if self._executable.rs_builders: for sd, rs_builder in runtime_service_builders.items(): - self._build_mp_actor(target_fn=target_fn, - builder=rs_builder) + self._messaging_infrastructure.build_actor( + target_fn=target_fn, + builder=rs_builder) def start(self, run_condition: AbstractRunCondition): if self._is_initialized: @@ -259,7 +262,7 @@ def stop(self): else: print("Runtime not started yet.") finally: - self._messaging_infrastructure.shutdown() + self._messaging_infrastructure.stop() def join(self): """Join all ports and processes""" @@ -273,8 +276,6 @@ def join(self): port.join() for port in self.runtime_to_service_data: port.join() - for actor in self._actors: - actor.join() @property def global_time(self): diff --git a/tests/lava/magma/compiler/channels/test_pypychannel.py b/tests/lava/magma/compiler/channels/test_pypychannel.py index 1c7cd853d..ff6af07c2 100644 --- a/tests/lava/magma/compiler/channels/test_pypychannel.py +++ b/tests/lava/magma/compiler/channels/test_pypychannel.py @@ -6,10 +6,20 @@ from lava.magma.compiler.channels.pypychannel import PyPyChannel +class MockInterface: + def __init__(self, smm): + self.smm = smm + + def get_channel(smm, data, size, name="test_channel") -> PyPyChannel: + mock = MockInterface(smm) channel = PyPyChannel( - smm=smm, src_name=name, dst_name=name, shape=data.shape, - dtype=data.dtype, size=size + message_infrastructure=mock, + src_name=name, + dst_name=name, + shape=data.shape, + dtype=data.dtype, + size=size ) return channel diff --git a/tests/lava/magma/compiler/test_channel_builder.py b/tests/lava/magma/compiler/test_channel_builder.py index a3564404f..e7465b72c 100644 --- a/tests/lava/magma/compiler/test_channel_builder.py +++ b/tests/lava/magma/compiler/test_channel_builder.py @@ -1,10 +1,14 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ +import typing as ty import unittest from multiprocessing.managers import SharedMemoryManager import numpy as np from lava.magma.compiler.builder import ChannelBuilderMp -from lava.magma.compiler.channels.interfaces import Channel +from lava.magma.compiler.channels.interfaces import Channel, ChannelType from lava.magma.compiler.utils import PortInitializer from lava.magma.compiler.channels.pypychannel import ( PyPyChannel, @@ -13,6 +17,14 @@ ) +class MockMessageInterface: + def __init__(self, smm): + self.smm = smm + + def channel_class(self, channel_type: ChannelType) -> ty.Type: + return PyPyChannel + + # ToDo: (AW) This test does not work for me. Something broken with d_type. # SMM does not seem to support numpy types. class TestChannelBuilder(unittest.TestCase): @@ -24,7 +36,7 @@ def test_channel_builder(self): name="mock", shape=(1, 2), d_type=np.int32, port_type='DOESNOTMATTER', size=64) channel_builder: ChannelBuilderMp = ChannelBuilderMp( - channel_type=PyPyChannel, + channel_type=ChannelType.PyPy, src_port_initializer=port_initializer, dst_port_initializer=port_initializer, src_process=None, @@ -32,7 +44,8 @@ def test_channel_builder(self): ) smm.start() - channel: Channel = channel_builder.build(smm) + mock = MockMessageInterface(smm) + channel: Channel = channel_builder.build(mock) assert isinstance(channel, PyPyChannel) assert isinstance(channel.src_port, CspSendPort) assert isinstance(channel.dst_port, CspRecvPort) diff --git a/tests/lava/magma/compiler/test_compiler.py b/tests/lava/magma/compiler/test_compiler.py index b8409015f..385a2dfc0 100644 --- a/tests/lava/magma/compiler/test_compiler.py +++ b/tests/lava/magma/compiler/test_compiler.py @@ -3,6 +3,7 @@ # See: https://spdx.org/licenses/ import unittest +from lava.magma.compiler.channels.interfaces import ChannelType from lava.magma.compiler.compiler import Compiler import lava.magma.compiler.exceptions as ex from lava.magma.core.decorator import implements, requires @@ -19,7 +20,6 @@ from lava.magma.core.model.py.type import LavaPyType from lava.magma.core.process.variable import Var, VarServer from lava.magma.core.resources import CPU -from lava.magma.compiler.channels.pypychannel import PyPyChannel # minimal process with an InPort and OutPortA @@ -617,7 +617,7 @@ def test_create_channel_builders(self): # Each channel builder should connect its source and destination # process and port # Let's check the first one in detail - self.assertEqual(cbs[0].channel_type, PyPyChannel) + self.assertEqual(cbs[0].channel_type, ChannelType.PyPy) self.assertEqual(cbs[0].src_process, p1) self.assertEqual(cbs[0].src_port_initializer.name, "out") self.assertEqual(cbs[0].src_port_initializer.shape, (1,)) diff --git a/tests/lava/magma/runtime/test_runtime.py b/tests/lava/magma/runtime/test_runtime.py index a0ae0bcf1..bf7b7f643 100644 --- a/tests/lava/magma/runtime/test_runtime.py +++ b/tests/lava/magma/runtime/test_runtime.py @@ -2,6 +2,7 @@ import unittest from lava.magma.compiler.executable import Executable +from lava.magma.core.process.message_interface_enum import ActorType from lava.magma.core.resources import HeadNode from lava.magma.core.run_conditions import RunSteps, AbstractRunCondition from lava.magma.compiler.node import Node, NodeConfig @@ -10,38 +11,44 @@ class TestRuntime(unittest.TestCase): def test_runtime_creation(self): + """Tests runtime constructor""" exe: Executable = Executable() run_cond: AbstractRunCondition = RunSteps(num_steps=10) - runtime: Runtime = Runtime(run_cond=run_cond, exe=exe) + mp = ActorType.MultiProcessing + runtime: Runtime = Runtime(run_cond=run_cond, + exe=exe, + message_infrastructure_type=mp) expected_type: ty.Type = Runtime assert isinstance( runtime, expected_type ), f"Expected type {expected_type} doesn't match {(type(runtime))}" def test_executable_node_config_assertion(self): + """Tests runtime constructions with expected constraints""" exec: Executable = Executable() run_cond: AbstractRunCondition = RunSteps(num_steps=10) - runtime1: Runtime = Runtime(run_cond, exec) + runtime1: Runtime = Runtime(run_cond, exec, ActorType.MultiProcessing) with self.assertRaises(AssertionError): runtime1.initialize() node: Node = Node(HeadNode, []) exec.node_configs.append(NodeConfig([node])) - runtime2: Runtime = Runtime(run_cond, exec) + runtime2: Runtime = Runtime(run_cond, exec, ActorType.MultiProcessing) runtime2.initialize() expected_type: ty.Type = Runtime assert isinstance( runtime2, expected_type ), f"Expected type {expected_type} doesn't match {(type(runtime2))}" + runtime2.stop() exec.node_configs[0].append(node) - runtime3: Runtime = Runtime(run_cond, exec) + runtime3: Runtime = Runtime(run_cond, exec, ActorType.MultiProcessing) with self.assertRaises(AssertionError): runtime3.initialize() exec.node_configs.append(NodeConfig([node])) - runtime4: Runtime = Runtime(run_cond, exec) + runtime4: Runtime = Runtime(run_cond, exec, ActorType.MultiProcessing) with self.assertRaises(AssertionError): runtime4.initialize() diff --git a/tests/lava/magma/runtime/test_runtime_service.py b/tests/lava/magma/runtime/test_runtime_service.py index b46630980..2fac5b376 100644 --- a/tests/lava/magma/runtime/test_runtime_service.py +++ b/tests/lava/magma/runtime/test_runtime_service.py @@ -11,9 +11,15 @@ from lava.magma.runtime.runtime_service import PyRuntimeService -def create_channel(messaging_infrastructure: SharedMemoryManager, name: str): +class MockInterface: + def __init__(self, smm): + self.smm = smm + + +def create_channel(smm: SharedMemoryManager, name: str): + mock = MockInterface(smm=smm) return PyPyChannel( - messaging_infrastructure, + mock, name, name, (1,),