Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enablement of RefPort to Var/VarPort connections #46

Merged
merged 11 commits into from
Nov 17, 2021
Merged
101 changes: 84 additions & 17 deletions src/lava/magma/compiler/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@

import numpy as np
from dataclasses import dataclass

from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort
from lava.magma.core.model.py.model import AbstractPyProcessModel
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.compiler.utils import VarInitializer, PortInitializer
from lava.magma.core.model.py.ports import AbstractPyPort, \
PyInPort, PyOutPort, PyRefPort
from lava.magma.compiler.utils import VarInitializer, PortInitializer, \
VarPortInitializer
from lava.magma.core.model.py.ports import (
AbstractPyPort,
PyInPort,
PyOutPort,
PyRefPort, PyVarPort,
)
from lava.magma.compiler.channels.interfaces import AbstractCspPort, Channel, \
ChannelType

Expand Down Expand Up @@ -91,6 +95,8 @@ def __init__(
self._model_id = model_id
self.vars: ty.Dict[str, VarInitializer] = {}
self.py_ports: ty.Dict[str, PortInitializer] = {}
self.ref_ports: ty.Dict[str, PortInitializer] = {}
self.var_ports: ty.Dict[str, VarPortInitializer] = {}
self.csp_ports: ty.Dict[str, ty.List[AbstractCspPort]] = {}
self.csp_rs_send_port: ty.Dict[str, CspSendPort] = {}
self.csp_rs_recv_port: ty.Dict[str, CspRecvPort] = {}
Expand Down Expand Up @@ -167,6 +173,8 @@ def check_all_vars_and_ports_set(self):
if (
attr_name not in self.vars
and attr_name not in self.py_ports
and attr_name not in self.ref_ports
and attr_name not in self.var_ports
):
raise AssertionError(
f"No LavaPyType '{attr_name}' found in ProcModel "
Expand Down Expand Up @@ -235,6 +243,29 @@ def set_py_ports(self, py_ports: ty.List[PortInitializer], check=True):
self._check_not_assigned_yet(self.py_ports, new_ports.keys(), "ports")
self.py_ports.update(new_ports)

def set_ref_ports(self, ref_ports: ty.List[PortInitializer]):
"""Set py_ports
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
ref_ports : ty.List[PortInitializer]
"""
self._check_members_exist(ref_ports, "Port")
new_ports = {p.name: p for p in ref_ports}
self._check_not_assigned_yet(self.ref_ports, new_ports.keys(), "ports")
self.ref_ports.update(new_ports)

def set_var_ports(self, var_ports: ty.List[VarPortInitializer]):
"""Set var_ports
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
var_ports : ty.List[VarPortInitializer]
"""
new_ports = {p.name: p for p in var_ports}
self._check_not_assigned_yet(self.var_ports, new_ports.keys(), "ports")
self.var_ports.update(new_ports)

def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
"""Set CSP Ports

Expand All @@ -253,22 +284,18 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
new_ports.setdefault(p.name, []).extend(
p if isinstance(p, list) else [p]
)
self._check_not_assigned_yet(
self.csp_ports, new_ports.keys(), "csp_ports"
)

# Check that there's a PyPort for each new CspPort
proc_name = self.proc_model.implements_process.__name__
for port_name in new_ports:
if not hasattr(self.proc_model, port_name):
raise AssertionError(
"PyProcessModel '{}' has \
no port named '{}'.".format(
proc_name, port_name
)
)
# Set new CspPorts
for key, ports in new_ports.items():
self.csp_ports.setdefault(key, []).extend(ports)
raise AssertionError("PyProcessModel '{}' has \
no port named '{}'.".format(proc_name, port_name))
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved

if port_name in self.csp_ports:
self.csp_ports[port_name].extend(new_ports[port_name])
else:
self.csp_ports[port_name] = new_ports[port_name]

def set_rs_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
"""Set RS CSP Ports
Expand Down Expand Up @@ -326,13 +353,53 @@ def build(self):
csp_ports = self.csp_ports[name]
if not isinstance(csp_ports, list):
csp_ports = [csp_ports]
port = port_cls(pm, csp_ports, p.shape, lt.d_type)
port = port_cls(csp_ports, pm, p.shape, lt.d_type)

# Create dynamic PyPort attribute on ProcModel
setattr(pm, name, port)
# Create private attribute for port precision
# setattr(pm, "_" + name + "_p", lt.precision)

# Initialize RefPorts
for name, p in self.ref_ports.items():
# Build PyPort
lt = self._get_lava_type(name)
port_cls = ty.cast(ty.Type[PyRefPort], lt.cls)
csp_recv = None
csp_send = None
if name in self.csp_ports:
csp_ports = self.csp_ports[name]
csp_recv = csp_ports[0] if isinstance(
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
csp_ports[0], CspRecvPort) else csp_ports[1]
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]

port = port_cls(csp_send, csp_recv, pm, p.shape, lt.d_type)

# Create dynamic PyPort attribute on ProcModel
setattr(pm, name, port)

# Initialize VarPorts
for name, p in self.var_ports.items():
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
# Build PyPort
if p.port_cls is None:
# VarPort is not connected
continue
port_cls = ty.cast(ty.Type[PyVarPort], p.port_cls)
csp_recv = None
csp_send = None
if name in self.csp_ports:
csp_ports = self.csp_ports[name]
csp_recv = csp_ports[0] if isinstance(
csp_ports[0], CspRecvPort) else csp_ports[1]
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]
port = port_cls(
p.var_name, csp_send, csp_recv, pm, p.shape, p.d_type)

# Create dynamic PyPort attribute on ProcModel
setattr(pm, name, port)

for port in self.csp_rs_recv_port.values():
if "service_to_process_cmd" in port.name:
pm.service_to_process_cmd = port
Expand Down
103 changes: 94 additions & 9 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@
from lava.magma.compiler.channels.interfaces import ChannelType
from lava.magma.compiler.executable import Executable
from lava.magma.compiler.node import NodeConfig, Node
from lava.magma.compiler.utils import VarInitializer, PortInitializer
from lava.magma.compiler.utils import VarInitializer, PortInitializer, \
VarPortInitializer
from lava.magma.core import resources
from lava.magma.core.model.c.model import AbstractCProcessModel
from lava.magma.core.model.model import AbstractProcessModel
from lava.magma.core.model.nc.model import AbstractNcProcessModel
from lava.magma.core.model.py.model import AbstractPyProcessModel
from lava.magma.core.model.py.ports import PyVarPort, RefVarTypeMapping
from lava.magma.core.model.sub.model import AbstractSubProcessModel
from lava.magma.core.process.ports.ports import AbstractPort, VarPort, \
ImplicitVarPort
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.resources import CPU, NeuroCore
from lava.magma.core.run_configs import RunConfig
Expand All @@ -49,7 +53,6 @@ def __init__(self, compile_cfg: ty.Optional[ty.Dict[str, ty.Any]] = None):
self._compile_config.update(compile_cfg)

# ToDo: (AW) Clean this up by avoiding redundant search paths
# ToDo: (AW) @PP Please include RefPorts/VarPorts in connection tracing
def _find_processes(self,
proc: AbstractProcess,
seen_procs: ty.List[AbstractProcess] = None) \
Expand All @@ -69,14 +72,14 @@ def _find_processes(self,
new_list: ty.List[AbstractProcess] = []

# add processes connecting to the main process
for in_port in proc.in_ports:
for in_port in proc.in_ports.members + proc.var_ports.members:
for con in in_port.in_connections:
new_list.append(con.process)
for con in in_port.out_connections:
new_list.append(con.process)

# add processes connecting from the main process
for out_port in proc.out_ports:
for out_port in proc.out_ports.members + proc.ref_ports.members:
for con in out_port.in_connections:
new_list.append(con.process)
for con in out_port.out_connections:
Expand Down Expand Up @@ -251,6 +254,35 @@ def _group_proc_by_model(proc_map: PROC_MAP) \

return grouped_models

# TODO: (PP) This currently only works for PyPorts - needs general solution
# TODO: (PP) Currently does not support 1:many/many:1 connections
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
@staticmethod
def _map_var_port_class(port: VarPort,
proc_groups: ty.Dict[ty.Type[AbstractProcessModel],
ty.List[AbstractProcess]]):
"""Derives the port class of a given VarPort from its source RefPort."""
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved

# Get the source RefPort of the VarPort
rp = port.get_src_ports()
if len(rp) > 0:
rp = rp[0]
else:
# VarPort is not connect, hence there is no LavaType
return None

# Get the ProcessModel of the source RefPort
r_pm = None
for pm in proc_groups:
if rp.process in proc_groups[pm]:
r_pm = pm

# Get the LavaType of the RefPort from its ProcessModel
lt = getattr(r_pm, rp.name)

# Return mapping of the RefPort class to VarPort class
return RefVarTypeMapping.get(lt.cls)

# TODO: (PP) possible shorten creation of PortInitializers
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
def _compile_proc_models(
self,
proc_groups: ty.Dict[ty.Type[AbstractProcessModel],
Expand All @@ -271,16 +303,42 @@ def _compile_proc_models(
# and Ports
v = [VarInitializer(v.name, v.shape, v.init, v.id)
for v in p.vars]
ports = (list(p.in_ports) + list(p.out_ports)
+ list(p.ref_ports))
ports = (list(p.in_ports) + list(p.out_ports))
ports = [PortInitializer(pt.name,
pt.shape,
getattr(pm, pt.name).d_type,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size) for pt in ports]
# Create RefPort (also use PortInitializers)
ref_ports = list(p.ref_ports)
ref_ports = [
PortInitializer(pt.name,
pt.shape,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size) for pt in ref_ports]
# Create VarPortInitializers (contain also the Var name)
var_ports = []
for pt in list(p.var_ports):
var_ports.append(
VarPortInitializer(
pt.name,
pt.shape,
pt.var.name,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size,
self._map_var_port_class(pt, proc_groups)))

# Set implicit VarPorts as attribute to ProcessModel
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(pt, ImplicitVarPort):
setattr(pm, pt.name, pt)

# Assigns initializers to builder
b.set_variables(v)
b.set_py_ports(ports)
b.set_ref_ports(ref_ports)
b.set_var_ports(var_ports)
b.check_all_vars_and_ports_set()
py_builders[p] = b
elif issubclass(pm, AbstractCProcessModel):
Expand Down Expand Up @@ -496,6 +554,26 @@ def _get_channel_type(src: ty.Type[AbstractProcessModel],
f"'({src.__name__}, {dst.__name__})' yet."
)

@staticmethod
def _get_port_dtype(port: AbstractPort,
proc_model: ty.Type[AbstractProcessModel]) -> type:
"""Returns the type of a port, as specified in the corresponding
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
ProcessModel."""

# In-, Out-, Ref- and explicit VarPorts
if hasattr(proc_model, port.name):
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
# Handle VarPorts (use dtype of corresponding Var)
if isinstance(port, VarPort):
return getattr(proc_model, port.var.name).d_type
return getattr(proc_model, port.name).d_type
# Implicitly created VarPorts
elif isinstance(port, ImplicitVarPort):
return getattr(proc_model, port.var.name).d_type
# Port has different name in Process and ProcessModel
else:
raise AssertionError("Port {!r} not found in "
"ProcessModel {!r}".format(port, proc_model))

# ToDo: (AW) Fix hard-coded hacks in this method and extend to other
# channel types
def _create_channel_builders(self, proc_map: PROC_MAP) \
Expand Down Expand Up @@ -525,7 +603,7 @@ def _create_channel_builders(self, proc_map: PROC_MAP) \
# Find destination ports for each source port
for src_pt in src_ports:
# Create PortInitializer for source port
src_pt_dtype = getattr(src_pm, src_pt.name).d_type
src_pt_dtype = self._get_port_dtype(src_pt, src_pm)
src_pt_init = PortInitializer(
src_pt.name, src_pt.shape, src_pt_dtype,
src_pt.__class__.__name__, ch_size)
Expand All @@ -540,14 +618,21 @@ def _create_channel_builders(self, proc_map: PROC_MAP) \
# Find appropriate channel type
ch_type = self._get_channel_type(src_pm, dst_pm)
# Create PortInitializer for destination port
dst_pt_d_type = getattr(dst_pm, dst_pt.name).d_type
dst_pt_d_type = self._get_port_dtype(dst_pt, dst_pm)
dst_pt_init = PortInitializer(
dst_pt.name, dst_pt.shape, dst_pt_d_type,
dst_pt.__class__.__name__, ch_size)
# Create new channel builder
chb = ChannelBuilderMp(
ch_type, src_p, dst_p, src_pt_init, dst_pt_init)
channel_builders.append(chb)
# Create additional channel builder for every VarPort
if isinstance(dst_pt, VarPort):
# RefPort to VarPort connections need channels for
# read and write
rv_chb = ChannelBuilderMp(
ch_type, dst_p, src_p, dst_pt_init, src_pt_init)
channel_builders.append(rv_chb)

return channel_builders

Expand Down
28 changes: 25 additions & 3 deletions src/lava/magma/compiler/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import typing as ty
from dataclasses import dataclass

import numpy as np


@dataclass
class VarInitializer:
Expand All @@ -16,6 +14,30 @@ class VarInitializer:
class PortInitializer:
name: str
shape: ty.Tuple[int, ...]
d_type: ty.Type[np.intc]
d_type: type
port_type: str
size: int


# check if can be a subclass of PortInitializer
@dataclass
class VarPortInitializer:
name: str
shape: ty.Tuple[int, ...]
var_name: str
d_type: type
port_type: str
size: int


# check if can be a subclass of PortInitializer
@dataclass
class VarPortInitializer:
name: str
shape: ty.Tuple[int, ...]
var_name: str
d_type: type
port_type: str
size: int
port_cls: type
port_cls: type
Loading