Skip to content

Commit

Permalink
Virtual ports between RefPorts and VarPorts (lava-nc#195)
Browse files Browse the repository at this point in the history
* permute initial implementation

Signed-off-by: bamsumit <[email protected]>

* Tests for permute ports

* Process property of virtual ports no longer returns None

Signed-off-by: Mathis Richter <[email protected]>

* Added initial run-unittest for flatten() from issue lava-nc#163

Signed-off-by: Mathis Richter <[email protected]>

* User-level API for TransposePort with unit tests

Signed-off-by: Mathis Richter <[email protected]>

* Fixed typo

Signed-off-by: Mathis Richter <[email protected]>

* Unit tests for flatten() and concat_with()

Signed-off-by: Mathis Richter <[email protected]>

* Unit tests for virtual ports in Processes that are executed (wip)

Signed-off-by: Mathis Richter <[email protected]>

* Preliminary implementation of virtual ports between OutPort and InPort (wip)

Signed-off-by: Mathis Richter <[email protected]>

* Fixing unit tests after merge

Signed-off-by: Mathis Richter <[email protected]>

* Added support for virtual ports between an OutPort and InPort of two hierarchical Processes

Signed-off-by: Mathis Richter <[email protected]>

* Clean up, exceptions, and generic unit tests for virtual port topologies

Signed-off-by: Mathis Richter <[email protected]>

* Fixed linter issues

Signed-off-by: Mathis Richter <[email protected]>

* Raising an exception when executing ConcatPort

Signed-off-by: Mathis Richter <[email protected]>

* Unit tests for virtual ports between OutPorts and InPorts in hierarchical Processes.

Signed-off-by: Mathis Richter <[email protected]>

* RefPort writing to an explicit VarPort via a virtual port.

Signed-off-by: Mathis Richter <[email protected]>

* RefPort reading from an explicit VarPort via a virtual port.

Signed-off-by: Mathis Richter <[email protected]>

* Fixed linter error

Signed-off-by: Mathis Richter <[email protected]>

* Unit tests for virtual ports between RefPorts and VarPorts in hierarchical Processes.

Signed-off-by: Mathis Richter <[email protected]>

* Fixed a docstring

Signed-off-by: Mathis Richter <[email protected]>

* Added docstrings to methods get_transform_func_fwd/bwd.

Signed-off-by: Mathis Richter <[email protected]>

Co-authored-by: bamsumit <[email protected]>
Co-authored-by: Marcus G K Williams <[email protected]>
  • Loading branch information
3 people committed Mar 2, 2022
1 parent abfc914 commit 1aa726b
Show file tree
Hide file tree
Showing 7 changed files with 858 additions and 112 deletions.
6 changes: 4 additions & 2 deletions src/lava/magma/compiler/builders/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ def build(self):
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]

port = port_cls(csp_send, csp_recv, pm, p.shape, lt.d_type)
port = port_cls(csp_send, csp_recv, pm, p.shape, lt.d_type,
p.transform_funcs)

# Create dynamic RefPort attribute on ProcModel
setattr(pm, name, port)
Expand All @@ -422,7 +423,8 @@ def build(self):
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]
port = port_cls(
p.var_name, csp_send, csp_recv, pm, p.shape, p.d_type)
p.var_name, csp_send, csp_recv, pm, p.shape, p.d_type,
p.transform_funcs)

# Create dynamic VarPort attribute on ProcModel
setattr(pm, name, port)
Expand Down
49 changes: 31 additions & 18 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,40 +352,53 @@ def _compile_proc_models(
for pt in (list(p.in_ports) + list(p.out_ports)):
# For all InPorts that receive input from
# virtual ports...
transform_funcs = None
transform_funcs = []
if isinstance(pt, InPort):
# ... extract a function pointer to the
# transformation function of each virtual port.
transform_funcs = \
[vp.get_transform_func()
[vp.get_transform_func_fwd()
for vp in pt.get_incoming_virtual_ports()]

pi = PortInitializer(pt.name,
pt.shape,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size,
transform_funcs)
ports.append(pi)

# Create RefPort (also use PortInitializers)
ref_ports = list(p.ref_ports)
ref_ports = [
PortInitializer(pt.name,
pt.shape,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size) for pt in ref_ports]
ref_ports = []
for pt in list(p.ref_ports):
transform_funcs = \
[vp.get_transform_func_bwd()
for vp in pt.get_outgoing_virtual_ports()]

pi = PortInitializer(pt.name,
pt.shape,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size,
transform_funcs)
ref_ports.append(pi)

# Create VarPortInitializers (contain also the Var name)
var_ports = []
for pt in list(p.var_ports):
var_ports.append(
VarPortInitializer(
pt.name,
pt.shape,
pt.var.name,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size,
self._map_var_port_class(pt, proc_groups)))
transform_funcs = \
[vp.get_transform_func_fwd()
for vp in pt.get_incoming_virtual_ports()]
pi = VarPortInitializer(
pt.name,
pt.shape,
pt.var.name,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size,
self._map_var_port_class(pt, proc_groups),
transform_funcs)
var_ports.append(pi)

# Set implicit VarPorts (created by connecting a RefPort
# directly to a Var) as attribute to ProcessModel
Expand Down
1 change: 1 addition & 0 deletions src/lava/magma/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ class VarPortInitializer:
port_type: str
size: int
port_cls: type
transform_funcs: ty.List[ft.partial] = None
2 changes: 1 addition & 1 deletion src/lava/magma/core/model/py/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _get_var(self):
data_port.send(enum_to_np(var))
elif isinstance(var, np.ndarray):
# FIXME: send a whole vector (also runtime_service.py)
var_iter = np.nditer(var)
var_iter = np.nditer(var, order='C')
num_items: np.integer = np.prod(var.shape)
data_port.send(enum_to_np(num_items))
for value in var_iter:
Expand Down
54 changes: 50 additions & 4 deletions src/lava/magma/core/model/py/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,10 @@ def __init__(self,
csp_recv_port: ty.Optional[CspRecvPort],
process_model: AbstractProcessModel,
shape: ty.Tuple[int, ...] = tuple(),
d_type: type = int):
d_type: type = int,
transform_funcs: ty.Optional[ty.List[ft.partial]] = None):

self._transform_funcs = transform_funcs
self._csp_recv_port = csp_recv_port
self._csp_send_port = csp_send_port
super().__init__(process_model, shape, d_type)
Expand Down Expand Up @@ -513,6 +516,25 @@ def write(
"""
pass

def _transform(self, recv_data: np.array) -> np.array:
"""Applies all transformation function pointers to the input data.
Parameters
----------
recv_data : numpy.ndarray
data received on the port that shall be transformed
Returns
-------
recv_data : numpy.ndarray
received data, transformed by the incoming virtual ports
"""
if self._transform_funcs:
# apply all transformation functions to the received data
for f in reversed(self._transform_funcs):
recv_data = f(recv_data)
return recv_data


class PyRefPortVectorDense(PyRefPort):
"""Python implementation of RefPort for dense vector data."""
Expand All @@ -529,8 +551,10 @@ def read(self) -> np.ndarray:
header = np.ones(self._csp_send_port.shape) * VarPortCmd.GET
self._csp_send_port.send(header)

return self._csp_recv_port.recv()
return self._transform(self._csp_recv_port.recv())

# TODO (MR): self._shape must be set to the correct shape when
# instantiating the Port
return np.zeros(self._shape, self._d_type)

def write(self, data: np.ndarray):
Expand Down Expand Up @@ -660,7 +684,10 @@ def __init__(self,
csp_recv_port: ty.Optional[CspRecvPort],
process_model: AbstractProcessModel,
shape: ty.Tuple[int, ...] = tuple(),
d_type: type = int):
d_type: type = int,
transform_funcs: ty.Optional[ty.List[ft.partial]] = None):

self._transform_funcs = transform_funcs
self._csp_recv_port = csp_recv_port
self._csp_send_port = csp_send_port
self.var_name = var_name
Expand Down Expand Up @@ -692,6 +719,25 @@ def service(self):
"""
pass

def _transform(self, recv_data: np.array) -> np.array:
"""Applies all transformation function pointers to the input data.
Parameters
----------
recv_data : numpy.ndarray
data received on the port that shall be transformed
Returns
-------
recv_data : numpy.ndarray
received data, transformed by the incoming virtual ports
"""
if self._transform_funcs:
# apply all transformation functions to the received data
for f in self._transform_funcs:
recv_data = f(recv_data)
return recv_data


class PyVarPortVectorDense(PyVarPort):
"""Python implementation of VarPort for dense vector data."""
Expand All @@ -712,7 +758,7 @@ def service(self):

# Set the value of the Var with the given data
if enum_equal(cmd, VarPortCmd.SET):
data = self._csp_recv_port.recv()
data = self._transform(self._csp_recv_port.recv())
setattr(self._process_model, self.var_name, data)
elif enum_equal(cmd, VarPortCmd.GET):
data = getattr(self._process_model, self.var_name)
Expand Down
Loading

0 comments on commit 1aa726b

Please sign in to comment.