Skip to content

Commit

Permalink
Alternative to Injector/Extractor Processes (lava-nc#835)
Browse files Browse the repository at this point in the history
* prototype implementing injector/extractor function, not wrapped in a Process

* modified injector and extractor classes

* fixed linting

---------

Co-authored-by: PhilippPlank <[email protected]>
Co-authored-by: Philipp Plank <[email protected]>
  • Loading branch information
3 people authored Mar 19, 2024
1 parent b50cdea commit 1cb36bd
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 106 deletions.
20 changes: 20 additions & 0 deletions src/lava/magma/core/process/ports/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,17 @@ class OutPort(AbstractIOPort, AbstractSrcPort):
sub processes.
"""

def __init__(self, shape: ty.Tuple[int, ...]):
super().__init__(shape)
self.external_pipe_flag = False
self.external_pipe_buffer_size = 64

def flag_external_pipe(self, buffer_size=None):
self.external_pipe_flag = True

if buffer_size is not None:
self.external_pipe_buffer_size = buffer_size

def connect(
self,
ports: ty.Union["AbstractIOPort", ty.List["AbstractIOPort"]],
Expand Down Expand Up @@ -493,6 +504,15 @@ def __init__(
super().__init__(shape)
self._reduce_op = reduce_op

self.external_pipe_flag = False
self.external_pipe_buffer_size = 64

def flag_external_pipe(self, buffer_size=None):
self.external_pipe_flag = True

if buffer_size is not None:
self.external_pipe_buffer_size = buffer_size

def connect(self,
ports: ty.Union["InPort", ty.List["InPort"]],
connection_configs: ty.Optional[ConnectionConfigs] = None):
Expand Down
46 changes: 44 additions & 2 deletions src/lava/magma/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
if ty.TYPE_CHECKING:
from lava.magma.core.process.process import AbstractProcess
from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort, \
CspSelector
CspSelector, PyPyChannel
from lava.magma.compiler.builders.channel_builder import (
ChannelBuilderMp, RuntimeChannelBuilderMp, ServiceChannelBuilderMp,
ChannelBuilderPyNc)
Expand All @@ -38,7 +38,7 @@
ChannelType
from lava.magma.compiler.executable import Executable
from lava.magma.compiler.node import NodeConfig
from lava.magma.core.process.ports.ports import create_port_id
from lava.magma.core.process.ports.ports import create_port_id, InPort, OutPort
from lava.magma.core.run_conditions import (AbstractRunCondition,
RunContinuous, RunSteps)
from lava.magma.compiler.channels.watchdog import WatchdogManagerInterface
Expand Down Expand Up @@ -308,6 +308,10 @@ def _build_processes(self):
proc._runtime = self
exception_q = Queue()
self.exception_q.append(exception_q)

# Create any external pypychannels
self._create_external_channels(proc, proc_builder)

self._messaging_infrastructure.build_actor(target_fn,
proc_builder,
exception_q)
Expand All @@ -323,6 +327,44 @@ def _build_runtime_services(self):
rs_builder,
self.exception_q[-1])

def _create_external_channels(self,
proc: AbstractProcess,
proc_builder: AbstractProcessBuilder):
"""Creates a csp channel which can be connected to/from a
non-procss/Lava python environment. This enables I/O to Lava from
external sources."""
for name, py_port in proc_builder.py_ports.items():
port = getattr(proc, name)

if port.external_pipe_flag:
if isinstance(port, InPort):
pypychannel = PyPyChannel(
message_infrastructure=self._messaging_infrastructure,
src_name="src",
dst_name=name,
shape=py_port.shape,
dtype=py_port.d_type,
size=port.external_pipe_buffer_size)

proc_builder.set_csp_ports([pypychannel.dst_port])

port.external_pipe_csp_send_port = pypychannel.src_port
port.external_pipe_csp_send_port.start()

if isinstance(port, OutPort):
pypychannel = PyPyChannel(
message_infrastructure=self._messaging_infrastructure,
src_name=name,
dst_name="dst",
shape=py_port.shape,
dtype=py_port.d_type,
size=port.external_pipe_buffer_size)

proc_builder.set_csp_ports([pypychannel.src_port])

port.external_pipe_csp_recv_port = pypychannel.dst_port
port.external_pipe_csp_recv_port.start()

def _get_resp_for_run(self):
"""
Gets response from RuntimeServices
Expand Down
59 changes: 23 additions & 36 deletions src/lava/proc/io/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import typing as ty

from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.ports.ports import InPort, RefPort, Var
from lava.magma.core.process.ports.ports import InPort, OutPort, RefPort, Var
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.py.ports import PyInPort, PyRefPort
from lava.magma.core.model.py.ports import PyInPort, PyOutPort, PyRefPort
from lava.magma.compiler.channels.pypychannel import PyPyChannel
from lava.magma.runtime.message_infrastructure.multiprocessing import \
MultiProcessing
Expand Down Expand Up @@ -51,9 +51,9 @@ class Extractor(AbstractProcess):
def __init__(self,
shape: ty.Tuple[int, ...],
buffer_size: ty.Optional[int] = 50,
channel_config: ty.Optional[utils.ChannelConfig] = None) -> \
None:
super().__init__()
channel_config: ty.Optional[utils.ChannelConfig] = None,
**kwargs) -> None:
super().__init__(shape_1=shape, **kwargs)

channel_config = channel_config or utils.ChannelConfig()

Expand All @@ -63,78 +63,65 @@ def __init__(self,

self._shape = shape

self._multi_processing = MultiProcessing()
self._multi_processing.start()

# Stands for ProcessModel to Process
pm_to_p = PyPyChannel(message_infrastructure=self._multi_processing,
src_name="src",
dst_name="dst",
shape=self._shape,
dtype=float,
size=buffer_size)
self._pm_to_p_dst_port = pm_to_p.dst_port
self._pm_to_p_dst_port.start()

self.proc_params["channel_config"] = channel_config
self.proc_params["pm_to_p_src_port"] = pm_to_p.src_port

self._receive_when_empty = channel_config.get_receive_empty_function()
self._receive_when_not_empty = \
channel_config.get_receive_not_empty_function()

self.in_port = InPort(shape=self._shape)
self.in_port = InPort(shape=shape)
self.out_port = OutPort(shape=shape)
self.out_port.flag_external_pipe(buffer_size=buffer_size)

def receive(self) -> np.ndarray:
"""Receive data from the ProcessModel.
The data is received from pm_to_p.dst_port.
The data is received from out_port.
Returns
----------
data : np.ndarray
Data received.
"""
elements_in_buffer = self._pm_to_p_dst_port._queue.qsize()
if not hasattr(self.out_port, 'external_pipe_csp_recv_port'):
raise AssertionError("The Runtime needs to be created before"
"calling <send>. Please use the method "
"<create_runtime> or <run> on your Lava"
" network before using <send>.")

elements_in_buffer = \
self.out_port.external_pipe_csp_recv_port._queue.qsize()

if elements_in_buffer == 0:
data = self._receive_when_empty(
self._pm_to_p_dst_port,
self.out_port.external_pipe_csp_recv_port,
np.zeros(self._shape))
else:
data = self._receive_when_not_empty(
self._pm_to_p_dst_port,
self.out_port.external_pipe_csp_recv_port,
np.zeros(self._shape),
elements_in_buffer)

return data

def __del__(self) -> None:
super().__del__()

self._multi_processing.stop()
self._pm_to_p_dst_port.join()


@implements(proc=Extractor, protocol=LoihiProtocol)
@requires(CPU)
class PyLoihiExtractorModel(PyLoihiProcessModel):
"""PyLoihiProcessModel for the Extractor Process."""
in_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float)
out_port: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)

def __init__(self, proc_params: dict) -> None:
super().__init__(proc_params=proc_params)

channel_config = self.proc_params["channel_config"]
self._pm_to_p_src_port = self.proc_params["pm_to_p_src_port"]
self._pm_to_p_src_port.start()

self._send = channel_config.get_send_full_function()

def run_spk(self) -> None:
self._send(self._pm_to_p_src_port, self.in_port.recv())

def __del__(self) -> None:
self._pm_to_p_src_port.join()
self._send(self.out_port.csp_ports[-1],
self.in_port.recv())


class VarWire(AbstractProcess):
Expand Down
69 changes: 26 additions & 43 deletions src/lava/proc/io/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
import typing as ty

from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.ports.ports import OutPort
from lava.magma.core.process.ports.ports import InPort, OutPort
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.py.ports import PyOutPort
from lava.magma.runtime.message_infrastructure.multiprocessing import \
MultiProcessing
from lava.magma.compiler.channels.pypychannel import PyPyChannel
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.proc.io import utils


Expand Down Expand Up @@ -47,74 +44,63 @@ class Injector(AbstractProcess):
buffer is full and how the dst_port behaves when the buffer is empty
and not empty.
"""

def __init__(self,
shape: ty.Tuple[int, ...],
buffer_size: ty.Optional[int] = 50,
channel_config: ty.Optional[utils.ChannelConfig] = None) -> \
None:
super().__init__()
channel_config: ty.Optional[utils.ChannelConfig] = None,
**kwargs) -> None:
super().__init__(shape_1=shape, **kwargs)

channel_config = channel_config or utils.ChannelConfig()

utils.validate_shape(shape)
utils.validate_buffer_size(buffer_size)
utils.validate_channel_config(channel_config)

self._multi_processing = MultiProcessing()
self._multi_processing.start()

# Stands for Process to ProcessModel
p_to_pm = PyPyChannel(message_infrastructure=self._multi_processing,
src_name="src",
dst_name="dst",
shape=shape,
dtype=float,
size=buffer_size)
self._p_to_pm_src_port = p_to_pm.src_port
self._p_to_pm_src_port.start()
self.in_port = InPort(shape=shape)
self.in_port.flag_external_pipe(buffer_size=buffer_size)
self.out_port = OutPort(shape=shape)

self.proc_params["shape"] = shape
self.proc_params["channel_config"] = channel_config
self.proc_params["p_to_pm_dst_port"] = p_to_pm.dst_port

self._send = channel_config.get_send_full_function()

self.out_port = OutPort(shape=shape)

def send(self, data: np.ndarray) -> None:
"""Send data to the ProcessModel.
The data is sent through p_to_pm.src_port.
"""Send data to connected process.
Parameters
----------
data : np.ndarray
Data to be sent.
"""
self._send(self._p_to_pm_src_port, data)

def __del__(self) -> None:
super().__del__()
self._multi_processing.stop()
self._p_to_pm_src_port.join()
Raises
------
AssertionError
If the runtime of the Lava network was not created.
"""
# The csp channel is created by the runtime
if hasattr(self.in_port, 'external_pipe_csp_send_port'):
self._send(self.in_port.external_pipe_csp_send_port, data)
else:
raise AssertionError("The Runtime needs to be created before"
"calling <send>. Please use the method "
"<create_runtime> or <run> on your Lava"
" network before using <send>.")


@implements(proc=Injector, protocol=LoihiProtocol)
@requires(CPU)
class PyLoihiInjectorModel(PyLoihiProcessModel):
"""PyLoihiProcessModel for the Injector Process."""
in_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float)
out_port: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)

def __init__(self, proc_params: dict) -> None:
super().__init__(proc_params=proc_params)

shape = self.proc_params["shape"]
channel_config = self.proc_params["channel_config"]
self._p_to_pm_dst_port = self.proc_params["p_to_pm_dst_port"]
self._p_to_pm_dst_port.start()

self._zeros = np.zeros(shape)

self._receive_when_empty = channel_config.get_receive_empty_function()
Expand All @@ -123,19 +109,16 @@ def __init__(self, proc_params: dict) -> None:

def run_spk(self) -> None:
self._zeros.fill(0)
elements_in_buffer = self._p_to_pm_dst_port._queue.qsize()
elements_in_buffer = self.in_port.csp_ports[-1]._queue.qsize()

if elements_in_buffer == 0:
data = self._receive_when_empty(
self._p_to_pm_dst_port,
self.in_port,
self._zeros)
else:
data = self._receive_when_not_empty(
self._p_to_pm_dst_port,
self.in_port,
self._zeros,
elements_in_buffer)

self.out_port.send(data)

def __del__(self) -> None:
self._p_to_pm_dst_port.join()
Loading

0 comments on commit 1cb36bd

Please sign in to comment.