Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable RefPort connections #24

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from lava.magma.core.model.nc.model import AbstractNcProcessModel
from lava.magma.core.model.py.model import AbstractPyProcessModel
from lava.magma.core.model.sub.model import AbstractSubProcessModel
from lava.magma.core.process.ports.ports import AbstractPort, VarPort
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.resources import CPU, NeuroCore
from lava.magma.core.run_configs import RunConfig
Expand All @@ -50,7 +51,6 @@ def __init__(self, compile_cfg: ty.Optional[ty.Dict[str, ty.Any]] = None):
self._compile_config.update(compile_cfg)

# ToDo: (AW) Clean this up by avoiding redundant search paths
# ToDo: (AW) @PP Please include RefPorts/VarPorts in connection tracing
def _find_processes(self,
proc: AbstractProcess,
seen_procs: ty.List[AbstractProcess] = None) \
Expand All @@ -70,14 +70,14 @@ def _find_processes(self,
new_list: ty.List[AbstractProcess] = []

# add processes connecting to the main process
for in_port in proc.in_ports:
for in_port in proc.in_ports.members + proc.var_ports.members:
for con in in_port.in_connections:
new_list.append(con.process)
for con in in_port.out_connections:
new_list.append(con.process)

# add processes connecting from the main process
for out_port in proc.out_ports:
for out_port in proc.out_ports.members + proc.ref_ports.members:
for con in out_port.in_connections:
new_list.append(con.process)
for con in out_port.out_connections:
Expand Down Expand Up @@ -273,10 +273,10 @@ def _compile_proc_models(
v = [VarInitializer(v.name, v.shape, v.init, v.id)
for v in p.vars]
ports = (list(p.in_ports) + list(p.out_ports)
+ list(p.ref_ports))
+ list(p.ref_ports) + list(p.var_ports))
ports = [PortInitializer(pt.name,
pt.shape,
getattr(pm, pt.name).d_type,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size) for pt in ports]
# Assigns initializers to builder
Expand Down Expand Up @@ -497,6 +497,23 @@ def _get_channel_type(src: ty.Type[AbstractProcessModel],
f"'({src.__name__}, {dst.__name__})' yet."
)

@staticmethod
def _get_port_dtype(port: AbstractPort,
proc_model: ty.Type[AbstractProcessModel]) -> type:
"""Returns the type of a port, as specified in the corresponding
ProcessModel."""

# In-, Out-, Ref- and explicit VarPorts
if hasattr(proc_model, port.name):
return getattr(proc_model, port.name).d_type
# Implicitly created VarPorts
elif isinstance(port, VarPort):
return getattr(proc_model, port.var.name).d_type
# Port has different name in Process and ProcessModel
else:
raise AssertionError("Port {!r} not found in "
"ProcessModel {!r}".format(port, proc_model))

# ToDo: (AW) Fix hard-coded hacks in this method and extend to other
# channel types
def _create_channel_builders(self, proc_map: PROC_MAP) \
Expand Down Expand Up @@ -526,7 +543,7 @@ def _create_channel_builders(self, proc_map: PROC_MAP) \
# Find destination ports for each source port
for src_pt in src_ports:
# Create PortInitializer for source port
src_pt_dtype = getattr(src_pm, src_pt.name).d_type
src_pt_dtype = self._get_port_dtype(src_pt, src_pm)
src_pt_init = PortInitializer(
src_pt.name, src_pt.shape, src_pt_dtype,
src_pt.__class__.__name__, ch_size)
Expand All @@ -541,7 +558,7 @@ def _create_channel_builders(self, proc_map: PROC_MAP) \
# Find appropriate channel type
ch_type = self._get_channel_type(src_pm, dst_pm)
# Create PortInitializer for destination port
dst_pt_d_type = getattr(dst_pm, dst_pt.name).d_type
dst_pt_d_type = self._get_port_dtype(dst_pt, dst_pm)
dst_pt_init = PortInitializer(
dst_pt.name, dst_pt.shape, dst_pt_d_type,
dst_pt.__class__.__name__, ch_size)
Expand Down
4 changes: 1 addition & 3 deletions lava/magma/compiler/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import typing as ty
from dataclasses import dataclass

import numpy as np


@dataclass
class VarInitializer:
Expand All @@ -16,6 +14,6 @@ class VarInitializer:
class PortInitializer:
name: str
shape: ty.Tuple[int, ...]
d_type: ty.Type[np.intc]
d_type: type
port_type: str
size: int
7 changes: 6 additions & 1 deletion lava/magma/core/process/ports/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,15 @@ def connect_var(self, variables: ty.Union[Var, ty.List[Var]]):
# Create a VarPort to wrap Var
vp = VarPort(v)
# Propagate name and parent process of Var to VarPort
vp.name = v.name + "_port"
vp.name = "_" + v.name + "_implicit_port"
if v.process is not None:
# Only assign when parent process is already assigned
vp.process = v.process
# VarPort Name could shadow existing attribute
if hasattr(v.process, vp.name):
raise AssertionError(
"Name of implicit VarPort might conflict"
" with existing attribute.")
var_ports.append(vp)
# Connect RefPort to VarPorts that wrap Vars
self.connect(var_ports)
Expand Down
93 changes: 91 additions & 2 deletions tests/magma/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from lava.magma.core.sync.domain import SyncDomain
from lava.magma.core.sync.protocol import AbstractSyncProtocol
from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol
from lava.magma.core.process.ports.ports import InPort, OutPort
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.process.ports.ports import (
InPort, OutPort, RefPort, VarPort)
from lava.magma.core.model.py.ports import PyInPort, PyOutPort, PyRefPort
from lava.magma.core.run_configs import RunConfig
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.process.variable import Var, VarServer
Expand All @@ -29,6 +30,7 @@ def __init__(self, **kwargs):
# Use ReduceOp to allow for multiple input connections
self.inp = InPort(shape=(1,), reduce_op=ReduceSum)
self.out = OutPort(shape=(1,))
self.ref = RefPort(shape=(10,))


# Another minimal process (does not matter that it's identical to ProcA)
Expand All @@ -39,6 +41,7 @@ def __init__(self, **kwargs):
self.inp = InPort(shape=(1,), reduce_op=ReduceSum)
self.out = OutPort(shape=(1,))
self.some_var = Var((10,), init=10)
self.var_port = VarPort(self.some_var)


# Another minimal process (does not matter that it's identical to ProcA)
Expand Down Expand Up @@ -74,6 +77,7 @@ def runtime_service(self):
class PyProcModelA(AbstractPyProcessModel):
inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int)
out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)
ref: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int)

def run(self):
pass
Expand All @@ -86,6 +90,7 @@ class PyProcModelB(AbstractPyProcessModel):
inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int)
out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)
some_var: int = LavaPyType(int, int)
var_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int)

def run(self):
pass
Expand Down Expand Up @@ -262,6 +267,30 @@ def test_find_process_circular(self):
self.assertEqual(set(procs4), all_procs)
self.assertEqual(set(procs6), all_procs)

def test_find_process_ref_ports(self):
"""Checks finding all processes for RefPort connection.
[p1 -> ref/var -> p2 -> in/out -> p3]"""

# Create processes
p1, p2, p3 = ProcA(), ProcB(), ProcC()

# Connect p1 (RefPort) with p2 (VarPort)
p1.ref.connect(p2.var_port)
# Connect p2 (OutPort) with p3 (InPort)
p2.out.connect(p3.inp)

# Regardless where we start searching...
PhilippPlank marked this conversation as resolved.
Show resolved Hide resolved
c = Compiler()
procs1 = c._find_processes(p1)
procs2 = c._find_processes(p2)
procs3 = c._find_processes(p3)

# ...we will find all of them
all_procs = {p1, p2, p3}
self.assertEqual(set(procs1), all_procs)
self.assertEqual(set(procs2), all_procs)
self.assertEqual(set(procs3), all_procs)

def test_find_proc_models(self):
"""Check finding of ProcModels that implement a Process."""

Expand Down Expand Up @@ -658,6 +687,66 @@ def test_create_channel_builders_hierarchical_process(self):
self.assertEqual(chb[0].src_process, p.procs.proc1)
self.assertEqual(chb[0].dst_process, p.procs.proc2)

def test_create_channel_builders_ref_ports(self):
"""Checks creation of channel builders when a process is connected
using a RefPort to another process VarPort."""

# Create a process with a RefPort (source)
src = ProcA()

# Create a process with a var (destination)
dst = ProcB()

# Connect them using RefPort and VarPort
src.ref.connect(dst.var_port)

# Create a manual proc_map
proc_map = {
src: PyProcModelA,
dst: PyProcModelB
}

# Create channel builders
c = Compiler()
cbs = c._create_channel_builders(proc_map)

# This should result in 1 channel builder
from lava.magma.compiler.builder import ChannelBuilderMp
self.assertEqual(len(cbs), 1)
self.assertIsInstance(cbs[0], ChannelBuilderMp)
self.assertEqual(cbs[0].src_process, src)
self.assertEqual(cbs[0].dst_process, dst)

def test_create_channel_builders_ref_ports_implicit(self):
"""Checks creation of channel builders when a process is connected
using a RefPort to another process Var (implicit VarPort)."""

# Create a process with a RefPort (source)
src = ProcA()

# Create a process with a var (destination)
dst = ProcB()

# Connect them using RefPort and Var (creates implicitly a VarPort)
src.ref.connect_var(dst.some_var)

# Create a manual proc_map
proc_map = {
src: PyProcModelA,
dst: PyProcModelB
}

# Create channel builders
c = Compiler()
cbs = c._create_channel_builders(proc_map)

# This should result in 1 channel builder
from lava.magma.compiler.builder import ChannelBuilderMp
self.assertEqual(len(cbs), 1)
self.assertIsInstance(cbs[0], ChannelBuilderMp)
self.assertEqual(cbs[0].src_process, src)
self.assertEqual(cbs[0].dst_process, dst)

# ToDo: (AW) @YS/@JM Please fix unit test by passing run_srv_builders to
# _create_exec_vars when ready
def test_create_py_exec_vars(self):
Expand Down
65 changes: 62 additions & 3 deletions tests/magma/core/process/test_ports.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
import unittest

from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.ports.ports import (
InPort,
OutPort,
Expand Down Expand Up @@ -245,9 +245,68 @@ def test_connect_RefPort_to_Var(self):

# In this case, the VarPort inherits its name and parent process from
# the Var it wraps
self.assertEqual(vp.name, v.name + "_port")
self.assertEqual(vp.name, "_" + v.name + "_implicit_port")
# (We can't check for the same parent process here because it has not
# been assigned ot the Var yet)
# been assigned to the Var yet)

def test_connect_RefPort_to_Var_process(self):
"""Checks connecting RefPort implicitly to Var, with registered
processes."""

# Create a mock parent process
class VarProcess(AbstractProcess):
...

# Create a Var and RefPort...
v = Var((1, 2, 3))
rp = RefPort((1, 2, 3))

# ...register a process for the Var
v.process = VarProcess()

# ...then connect them directly via connect_var(..)
rp.connect_var(v)

# This has the same effect as connecting a RefPort explicitly via a
# VarPort to a Var...
self.assertEqual(rp.get_dst_vars(), [v])
# ... but still creates a VarPort implicitly
vp = rp.get_dst_ports()[0]
self.assertIsInstance(vp, VarPort)
# ... which wraps the original Var
self.assertEqual(vp.var, v)

# In this case, the VarPort inherits its name and parent process from
# the Var it wraps
self.assertEqual(vp.name, "_" + v.name + "_implicit_port")
self.assertEqual(vp.process, v.process)

def test_connect_RefPort_to_Var_process_conflict(self):
"""Checks connecting RefPort implicitly to Var, with registered
processes and conflicting names. -> AssertionError"""

# Create a mock parent process
class VarProcess(AbstractProcess):
# Attribute is named like our implicit VarPort after creation
_existing_attr_implicit_port = None

# Create a Var and RefPort...
v = Var((1, 2, 3))
rp = RefPort((1, 2, 3))

# Create a Var and RefPort...
v = Var((1, 2, 3))
rp = RefPort((1, 2, 3))

# ...register a process for the Var and name it so it conflicts with
# the attribute ov VarProcess (very unlikely to happen)
v.process = VarProcess()
v.name = "existing_attr"

# ... and connect it directly via connect_var(..)
# The naming conflict should raise an AssertionError
with self.assertRaises(AssertionError):
rp.connect_var(v)

def test_connect_RefPort_to_many_Vars(self):
"""Checks that RefPort can be connected to many Vars."""
Expand Down