Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shmem Ports & Channel refactor development #307

Merged
merged 39 commits into from
Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7461e2a
Fix python wrapper build for shmem port & channel
hexu33 Aug 12, 2022
642d813
Merge branch 'messaging_refactor_develop' of https://github.com/hexu3…
hexu33 Aug 12, 2022
b0aeb69
Fix build bugs
hexu33 Aug 12, 2022
d32bff0
Add ChannelFactory to createe channels
hexu33 Aug 12, 2022
542188d
Fix cpplint error
hexu33 Aug 12, 2022
fead41d
Add functions for ports
hexu33 Aug 16, 2022
5dbbf90
Change shmem ports functions name following Google style
hexu33 Aug 16, 2022
0d33260
Fix shmem port build issues
hexu33 Aug 16, 2022
5184a14
Fix build bugs & Add PortProxy class
hexu33 Aug 16, 2022
873c6cd
Use shared_ptr to replace original pointers
hexu33 Aug 16, 2022
634dced
Fix cpplint errors
hexu33 Aug 16, 2022
d7d3cfa
Add GetChannel wrapper & Make channel return port_proxy
hexu33 Aug 17, 2022
1af6d79
Merge branch 'messaging_refactor_develop' into messaging_refactor_dev…
hexu33 Aug 17, 2022
b501b90
Update CMakeLists.txt
hexu33 Aug 17, 2022
094adc9
Add test for channel& port python wrapper[Todo: bug fix]
hexu33 Aug 18, 2022
e61b780
Update channel testcase
hexu33 Aug 18, 2022
a9d0766
Update message_infrastructure_py_wrapper.cc
hexu33 Aug 18, 2022
fd859cf
Update Channel Type and fix lint error
hongdami Aug 18, 2022
c13cc6c
Wrapper Shemchannel to pypychannel.py
hongdami Aug 18, 2022
50d782a
Fix returning base class pointer bug
hexu33 Aug 19, 2022
08a5a3a
Add implementation for AbstractPort
hexu33 Aug 19, 2022
8d170fa
Fix bug of implementation for AbstractPort
hexu33 Aug 19, 2022
25d6dfa
Update XPort (#2)
killight98 Aug 19, 2022
ed4076a
Hongda dev (#3)
hongdami Aug 19, 2022
7148a72
Fix bug for py_wrapper
hexu33 Aug 19, 2022
399ee74
Fix python lint error
hongdami Aug 19, 2022
a5dc955
Merge branch 'messaging_refactor_develop' into messaging_refactor_dev…
hexu33 Aug 19, 2022
893915b
Merge branch 'messaging_refactor_develop' of https://github.com/hexu3…
hongdami Aug 19, 2022
2ed57ab
Fix test_xport.py bug
hexu33 Aug 19, 2022
4777c3b
Fix RecvPort Queue data struct
hexu33 Aug 19, 2022
e73b101
Fix cpplint error
hexu33 Aug 19, 2022
c3514c2
align py ports interface with message_infrastructure lib
killight98 Aug 19, 2022
a4ffe72
amend ports.py
killight98 Aug 19, 2022
3e2b7cf
Update Selector, TODO: change channel_broker
hongdami Aug 19, 2022
45d54ae
fix lint fail
killight98 Aug 19, 2022
798a8f6
Merge branch
hongdami Aug 19, 2022
193fe8a
Merge branch 'messaging_refactor_develop' of https://github.com/hexu3…
hongdami Aug 19, 2022
46761d0
fix cpplint error
hongdami Aug 19, 2022
ee556b8
fix import fails
killight98 Aug 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions src/lava/magma/compiler/builders/channel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
AbstractProcessModel
from lava.magma.compiler.builders. \
runtimeservice_builder import RuntimeServiceBuilder
from lava.magma.compiler.channels.interfaces import (
from message_infrastructure import (
Channel,
ChannelType,
ChannelTransferType,
get_channel_factory,
ShmemChannel,
SharedMemory
)
from lava.magma.compiler.channels.interfaces import ChannelType
from lava.magma.compiler.utils import PortInitializer
from message_infrastructure \
.message_infrastructure_interface import (MessageInfrastructureInterface)
Expand All @@ -29,7 +33,7 @@ class ChannelBuilderMp(AbstractChannelBuilder):
and multi processing backbone.
"""

channel_type: ChannelType
channel_type: ChannelTransferType
src_process: "AbstractProcess"
dst_process: "AbstractProcess"
src_port_initializer: PortInitializer
Expand All @@ -54,9 +58,15 @@ def build(
Exception
Can't build channel of type specified
"""
channel_class = messaging_infrastructure.channel_class(
channel_type=self.channel_type
)
channel_factory = get_channel_factory()
shm = messaging_infrastructure.smm.alloc_mem(
self.src_port_initializer.size)
return channel_factory.get_channel(ChannelTransferType.SHMEMCHANNEL,
shm,
self.src_port_initializer.d_type,
self.src_port_initializer.size,
self.src_port_initializer.shape,
self.src_port_initializer.name)
return channel_class(
messaging_infrastructure,
self.src_port_initializer.name,
Expand All @@ -73,7 +83,7 @@ class ServiceChannelBuilderMp(AbstractChannelBuilder):
as messaging and multi processing backbone.
"""

channel_type: ChannelType
channel_type: ChannelTransferType
src_process: ty.Union[RuntimeServiceBuilder,
ty.Type["AbstractProcessModel"]]
dst_process: ty.Union[RuntimeServiceBuilder,
Expand All @@ -99,9 +109,15 @@ def build(
Exception
Can't build channel of type specified
"""
channel_class = messaging_infrastructure.channel_class(
channel_type=self.channel_type
)
channel_factory = get_channel_factory()
shm = messaging_infrastructure.smm.alloc_mem(
self.port_initializer.size)
return channel_factory.get_channel(ChannelTransferType.SHMEMCHANNEL,
shm,
self.port_initializer.d_type,
self.port_initializer.size,
self.port_initializer.shape,
self.port_initializer.name)

channel_name: str = self.port_initializer.name
return channel_class(
Expand All @@ -120,7 +136,7 @@ class RuntimeChannelBuilderMp(AbstractChannelBuilder):
used as messaging and multi processing backbone.
"""

channel_type: ChannelType
channel_type: ChannelTransferType
src_process: ty.Union[RuntimeServiceBuilder, ty.Type["Runtime"]]
dst_process: ty.Union[RuntimeServiceBuilder, ty.Type["Runtime"]]
port_initializer: PortInitializer
Expand All @@ -144,9 +160,15 @@ def build(
Exception
Can't build channel of type specified
"""
channel_class = messaging_infrastructure.channel_class(
channel_type=self.channel_type
)
channel_factory = get_channel_factory()
shm = messaging_infrastructure.smm.alloc_mem(
self.port_initializer.size)
return channel_factory.get_channel(ChannelTransferType.SHMEMCHANNEL,
shm,
self.port_initializer.d_type,
self.port_initializer.size,
self.port_initializer.shape,
self.port_initializer.name)

channel_name: str = self.port_initializer.name
return channel_class(
Expand All @@ -165,7 +187,7 @@ class ChannelBuilderNx(AbstractChannelBuilder):
infrastructure.
"""

channel_type: ChannelType
channel_type: ChannelTransferType
src_process: "AbstractProcess"
dst_process: "AbstractProcess"
src_port_initializer: PortInitializer
Expand Down
53 changes: 29 additions & 24 deletions src/lava/magma/compiler/builders/py_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
import numpy as np
from lava.magma.compiler.builders.interfaces import AbstractProcessBuilder

from lava.magma.compiler.channels.interfaces import AbstractCspPort
from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort
from message_infrastructure import (
AbstractTransferPort,
RecvPort,
SendPort
)
from lava.magma.compiler.utils import (PortInitializer, VarInitializer,
VarPortInitializer)
from lava.magma.core.model.py.model import AbstractPyProcessModel
from lava.magma.core.model.py.ports import (AbstractPyIOPort,
IdentityTransformer, PyInPort,
PyOutPort, PyRefPort, PyVarPort,
VirtualPortTransformer)
from message_infrastructure.ports import (AbstractPyIOPort,
IdentityTransformer, PyInPort,
PyOutPort, PyRefPort, PyVarPort,
VirtualPortTransformer)
from lava.magma.core.model.py.type import LavaPyType


Expand Down Expand Up @@ -60,10 +63,11 @@ def __init__(
self.py_ports: ty.Dict[str, PortInitializer] = {}
self.ref_ports: ty.Dict[str, PortInitializer] = {}
self.var_ports: ty.Dict[str, VarPortInitializer] = {}
self.csp_ports: ty.Dict[str, ty.List[AbstractCspPort]] = {}
self._csp_port_map: ty.Dict[str, ty.Dict[str, AbstractCspPort]] = {}
self.csp_rs_send_port: ty.Dict[str, CspSendPort] = {}
self.csp_rs_recv_port: ty.Dict[str, CspRecvPort] = {}
self.csp_ports: ty.Dict[str, ty.List[AbstractTransferPort]] = {}
self._csp_port_map: ty.Dict[str,
ty.Dict[str, AbstractTransferPort]] = {}
self.csp_rs_send_port: ty.Dict[str, SendPort] = {}
self.csp_rs_recv_port: ty.Dict[str, RecvPort] = {}
self.proc_params = proc_params

def check_all_vars_and_ports_set(self):
Expand Down Expand Up @@ -165,13 +169,13 @@ def set_var_ports(self, var_ports: ty.List[VarPortInitializer]):
self._check_not_assigned_yet(self.var_ports, new_ports.keys(), "ports")
self.var_ports.update(new_ports)

def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
def set_csp_ports(self, csp_ports: ty.List[AbstractTransferPort]):
"""Appends the given list of CspPorts to the ProcessModel. Used by the
runtime to configure csp ports during initialization (_build_channels).

Parameters
----------
csp_ports : ty.List[AbstractCspPort]
csp_ports : ty.List[AbstractTransferPort]


Raises
Expand All @@ -181,7 +185,7 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
"""
new_ports = {}
for p in csp_ports:
new_ports.setdefault(p.name, []).extend(
new_ports.setdefault(p.name(), []).extend(
p if isinstance(p, list) else [p]
)

Expand All @@ -197,7 +201,8 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
else:
self.csp_ports[port_name] = new_ports[port_name]

def add_csp_port_mapping(self, py_port_id: str, csp_port: AbstractCspPort):
def add_csp_port_mapping(self, py_port_id: str,
csp_port: AbstractTransferPort):
"""Appends a mapping from a PyPort ID to a CSP port. This is used
to associate a CSP port in a PyPort with transformation functions
that implement the behavior of virtual ports.
Expand All @@ -212,10 +217,10 @@ def add_csp_port_mapping(self, py_port_id: str, csp_port: AbstractCspPort):
"""
# Add or update the mapping
self._csp_port_map.setdefault(
csp_port.name, {}
csp_port.name(), {}
).update({py_port_id: csp_port})

def set_rs_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
def set_rs_csp_ports(self, csp_ports: ty.List[AbstractTransferPort]):
"""Set RS CSP Ports

Parameters
Expand All @@ -224,10 +229,10 @@ def set_rs_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):

"""
for port in csp_ports:
if isinstance(port, CspSendPort):
self.csp_rs_send_port.update({port.name: port})
if isinstance(port, CspRecvPort):
self.csp_rs_recv_port.update({port.name: port})
if isinstance(port, SendPort):
self.csp_rs_send_port.update({port.name(): port})
if isinstance(port, RecvPort):
self.csp_rs_recv_port.update({port.name(): port})

def _get_lava_type(self, name: str) -> LavaPyType:
return getattr(self.proc_model, name)
Expand Down Expand Up @@ -302,9 +307,9 @@ def build(self):
if name in self.csp_ports:
csp_ports = self.csp_ports[name]
csp_recv = csp_ports[0] if isinstance(
csp_ports[0], CspRecvPort) else csp_ports[1]
csp_ports[0], RecvPort) else csp_ports[1]
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]
csp_ports[0], SendPort) else csp_ports[1]

transformer = VirtualPortTransformer(
self._csp_port_map[name],
Expand All @@ -329,9 +334,9 @@ def build(self):
if name in self.csp_ports:
csp_ports = self.csp_ports[name]
csp_recv = csp_ports[0] if isinstance(
csp_ports[0], CspRecvPort) else csp_ports[1]
csp_ports[0], RecvPort) else csp_ports[1]
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]
csp_ports[0], SendPort) else csp_ports[1]

transformer = VirtualPortTransformer(
self._csp_port_map[name],
Expand Down
44 changes: 24 additions & 20 deletions src/lava/magma/compiler/builders/runtimeservice_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
import logging
import typing as ty

from lava.magma.compiler.channels.interfaces import AbstractCspPort
from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort
from message_infrastructure import (
AbstractTransferPort,
RecvPort,
SendPort
)

from lava.magma.core.sync.protocol import AbstractSyncProtocol
from lava.magma.runtime.runtime_services.enums import LoihiVersion
from lava.magma.runtime.runtime_services.runtime_service import \
Expand Down Expand Up @@ -49,17 +53,17 @@ def __init__(
self.log.setLevel(loglevel)
self._runtime_service_id = runtime_service_id
self._model_ids: ty.List[int] = model_ids
self.csp_send_port: ty.Dict[str, CspSendPort] = {}
self.csp_recv_port: ty.Dict[str, CspRecvPort] = {}
self.csp_proc_send_port: ty.Dict[str, CspSendPort] = {}
self.csp_proc_recv_port: ty.Dict[str, CspRecvPort] = {}
self.csp_send_port: ty.Dict[str, SendPort] = {}
self.csp_recv_port: ty.Dict[str, RecvPort] = {}
self.csp_proc_send_port: ty.Dict[str, SendPort] = {}
self.csp_proc_recv_port: ty.Dict[str, RecvPort] = {}
self.loihi_version: ty.Type[LoihiVersion] = loihi_version

@property
def runtime_service_id(self):
return self._runtime_service_id

def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
def set_csp_ports(self, csp_ports: ty.List[AbstractTransferPort]):
"""Set CSP Ports

Parameters
Expand All @@ -68,12 +72,12 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):

"""
for port in csp_ports:
if isinstance(port, CspSendPort):
self.csp_send_port.update({port.name: port})
if isinstance(port, CspRecvPort):
self.csp_recv_port.update({port.name: port})
if isinstance(port, SendPort):
self.csp_send_port.update({port.name(): port})
if isinstance(port, RecvPort):
self.csp_recv_port.update({port.name(): port})

def set_csp_proc_ports(self, csp_ports: ty.List[AbstractCspPort]):
def set_csp_proc_ports(self, csp_ports: ty.List[AbstractTransferPort]):
"""Set CSP Process Ports

Parameters
Expand All @@ -82,10 +86,10 @@ def set_csp_proc_ports(self, csp_ports: ty.List[AbstractCspPort]):

"""
for port in csp_ports:
if isinstance(port, CspSendPort):
self.csp_proc_send_port.update({port.name: port})
if isinstance(port, CspRecvPort):
self.csp_proc_recv_port.update({port.name: port})
if isinstance(port, SendPort):
self.csp_proc_send_port.update({port.name(): port})
if isinstance(port, RecvPort):
self.csp_proc_recv_port.update({port.name(): port})

def build(self) -> AbstractRuntimeService:
"""Build the runtime service
Expand Down Expand Up @@ -115,21 +119,21 @@ def build(self) -> AbstractRuntimeService:

if not nxsdk_rts:
for port in self.csp_proc_send_port.values():
if "service_to_process" in port.name:
if "service_to_process" in port.name():
rs.service_to_process.append(port)

for port in self.csp_proc_recv_port.values():
if "process_to_service" in port.name:
if "process_to_service" in port.name():
rs.process_to_service.append(port)

self.log.debug("Setup 'RuntimeService <--> Rrocess; ports")

for port in self.csp_send_port.values():
if "service_to_runtime" in port.name:
if "service_to_runtime" in port.name():
rs.service_to_runtime = port

for port in self.csp_recv_port.values():
if "runtime_to_service" in port.name:
if "runtime_to_service" in port.name():
rs.runtime_to_service = port

self.log.debug("Setup 'Runtime <--> RuntimeService' ports")
Expand Down
1 change: 1 addition & 0 deletions src/lava/magma/compiler/channels/pypychannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from message_infrastructure.message_infrastructure_interface \
import (
MessageInfrastructureInterface)
"""Depricated"""


@dataclass
Expand Down
10 changes: 6 additions & 4 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class AbstractNcProcessModel:
from lava.magma.runtime.runtime import Runtime
from lava.magma.runtime.runtime_services.enums import LoihiVersion

from message_infrastructure import ChannelTransferType


class Compiler:
"""Lava processes Compiler, called from any process in a process network.
Expand Down Expand Up @@ -728,7 +730,7 @@ def _create_sync_channel_builders(
sync_channel_builders: ty.List[AbstractChannelBuilder] = []
for sync_domain in rsb:
runtime_to_service = RuntimeChannelBuilderMp(
ChannelType.PyPy,
ChannelTransferType.SHMEMCHANNEL,
Runtime,
rsb[sync_domain],
self._create_mgmt_port_initializer(
Expand All @@ -738,7 +740,7 @@ def _create_sync_channel_builders(
sync_channel_builders.append(runtime_to_service)

service_to_runtime = RuntimeChannelBuilderMp(
ChannelType.PyPy,
ChannelTransferType.SHMEMCHANNEL,
rsb[sync_domain],
Runtime,
self._create_mgmt_port_initializer(
Expand All @@ -750,7 +752,7 @@ def _create_sync_channel_builders(
for process in sync_domain.processes:
if issubclass(process.model_class, AbstractPyProcessModel):
service_to_process = ServiceChannelBuilderMp(
ChannelType.PyPy,
ChannelTransferType.SHMEMCHANNEL,
rsb[sync_domain],
process,
self._create_mgmt_port_initializer(
Expand All @@ -760,7 +762,7 @@ def _create_sync_channel_builders(
sync_channel_builders.append(service_to_process)

process_to_service = ServiceChannelBuilderMp(
ChannelType.PyPy,
ChannelTransferType.SHMEMCHANNEL,
process,
rsb[sync_domain],
self._create_mgmt_port_initializer(
Expand Down
Loading