Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Virtual ports between RefPorts and VarPorts #195

Merged
merged 29 commits into from
Mar 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d3a65e7
permute initial implementation
bamsumit Jan 21, 2022
58d5a14
Tests for permute ports
bamsumit Jan 25, 2022
3891c8c
Process property of virtual ports no longer returns None
mathisrichter Feb 4, 2022
eba5df9
Added initial run-unittest for flatten() from issue #163
mathisrichter Feb 4, 2022
917a497
User-level API for TransposePort with unit tests
mathisrichter Feb 4, 2022
0f33774
Fixed typo
mathisrichter Feb 4, 2022
1537c41
Unit tests for flatten() and concat_with()
mathisrichter Feb 4, 2022
938b7db
Unit tests for virtual ports in Processes that are executed (wip)
mathisrichter Feb 11, 2022
d0eff5a
Preliminary implementation of virtual ports between OutPort and InPor…
mathisrichter Feb 17, 2022
5076ab4
Merge branch 'main' of https://github.com/lava-nc/lava into permute
mathisrichter Feb 17, 2022
846c1b3
Fixing unit tests after merge
mathisrichter Feb 17, 2022
9bde78e
Merge branch 'main' into permute
bamsumit Feb 21, 2022
effd522
Added support for virtual ports between an OutPort and InPort of two …
mathisrichter Feb 22, 2022
a1be80f
Merge branch 'permute' of https://github.com/bamsumit/lava into permute
mathisrichter Feb 22, 2022
c5d0f5f
Clean up, exceptions, and generic unit tests for virtual port topologies
mathisrichter Feb 22, 2022
abf03bd
Fixed linter issues
mathisrichter Feb 22, 2022
c02a44a
Merge branch 'main' into permute
mgkwill Feb 24, 2022
166ec46
Merge branch 'main' into permute
bamsumit Feb 25, 2022
6ca8a51
Raising an exception when executing ConcatPort
mathisrichter Feb 25, 2022
9ebfc6f
Merge branch 'main' of https://github.com/lava-nc/lava into permute
mathisrichter Feb 25, 2022
4873eca
Merge branch 'permute' of https://github.com/bamsumit/lava into permute
mathisrichter Feb 25, 2022
71d1e8b
Merge
mathisrichter Feb 27, 2022
0d5b0bf
Unit tests for virtual ports between OutPorts and InPorts in hierarch…
mathisrichter Feb 27, 2022
6a15666
RefPort writing to an explicit VarPort via a virtual port.
mathisrichter Feb 28, 2022
9ceddb1
RefPort reading from an explicit VarPort via a virtual port.
mathisrichter Feb 28, 2022
fbf1f54
Fixed linter error
mathisrichter Mar 1, 2022
ae4debf
Unit tests for virtual ports between RefPorts and VarPorts in hierarc…
mathisrichter Mar 1, 2022
83fd8d9
Fixed a docstring
mathisrichter Mar 1, 2022
eaf55ca
Added docstrings to methods get_transform_func_fwd/bwd.
mathisrichter Mar 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/lava/magma/compiler/builders/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ def build(self):
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]

port = port_cls(csp_send, csp_recv, pm, p.shape, lt.d_type)
port = port_cls(csp_send, csp_recv, pm, p.shape, lt.d_type,
p.transform_funcs)

# Create dynamic RefPort attribute on ProcModel
setattr(pm, name, port)
Expand All @@ -422,7 +423,8 @@ def build(self):
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]
port = port_cls(
p.var_name, csp_send, csp_recv, pm, p.shape, p.d_type)
p.var_name, csp_send, csp_recv, pm, p.shape, p.d_type,
p.transform_funcs)

# Create dynamic VarPort attribute on ProcModel
setattr(pm, name, port)
Expand Down
49 changes: 31 additions & 18 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,40 +352,53 @@ def _compile_proc_models(
for pt in (list(p.in_ports) + list(p.out_ports)):
# For all InPorts that receive input from
# virtual ports...
transform_funcs = None
transform_funcs = []
if isinstance(pt, InPort):
# ... extract a function pointer to the
# transformation function of each virtual port.
transform_funcs = \
[vp.get_transform_func()
[vp.get_transform_func_fwd()
for vp in pt.get_incoming_virtual_ports()]

pi = PortInitializer(pt.name,
pt.shape,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size,
transform_funcs)
ports.append(pi)

# Create RefPort (also use PortInitializers)
ref_ports = list(p.ref_ports)
ref_ports = [
PortInitializer(pt.name,
pt.shape,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size) for pt in ref_ports]
ref_ports = []
for pt in list(p.ref_ports):
transform_funcs = \
[vp.get_transform_func_bwd()
for vp in pt.get_outgoing_virtual_ports()]

pi = PortInitializer(pt.name,
pt.shape,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size,
transform_funcs)
ref_ports.append(pi)

# Create VarPortInitializers (contain also the Var name)
var_ports = []
for pt in list(p.var_ports):
var_ports.append(
VarPortInitializer(
pt.name,
pt.shape,
pt.var.name,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size,
self._map_var_port_class(pt, proc_groups)))
transform_funcs = \
[vp.get_transform_func_fwd()
for vp in pt.get_incoming_virtual_ports()]
pi = VarPortInitializer(
pt.name,
pt.shape,
pt.var.name,
self._get_port_dtype(pt, pm),
pt.__class__.__name__,
pp_ch_size,
self._map_var_port_class(pt, proc_groups),
transform_funcs)
var_ports.append(pi)

# Set implicit VarPorts (created by connecting a RefPort
# directly to a Var) as attribute to ProcessModel
Expand Down
1 change: 1 addition & 0 deletions src/lava/magma/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ class VarPortInitializer:
port_type: str
size: int
port_cls: type
transform_funcs: ty.List[ft.partial] = None
2 changes: 1 addition & 1 deletion src/lava/magma/core/model/py/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _get_var(self):
data_port.send(enum_to_np(var))
elif isinstance(var, np.ndarray):
# FIXME: send a whole vector (also runtime_service.py)
var_iter = np.nditer(var)
var_iter = np.nditer(var, order='C')
num_items: np.integer = np.prod(var.shape)
data_port.send(enum_to_np(num_items))
for value in var_iter:
Expand Down
54 changes: 50 additions & 4 deletions src/lava/magma/core/model/py/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,10 @@ def __init__(self,
csp_recv_port: ty.Optional[CspRecvPort],
process_model: AbstractProcessModel,
shape: ty.Tuple[int, ...] = tuple(),
d_type: type = int):
d_type: type = int,
transform_funcs: ty.Optional[ty.List[ft.partial]] = None):

self._transform_funcs = transform_funcs
self._csp_recv_port = csp_recv_port
self._csp_send_port = csp_send_port
super().__init__(process_model, shape, d_type)
Expand Down Expand Up @@ -513,6 +516,25 @@ def write(
"""
pass

def _transform(self, recv_data: np.array) -> np.array:
"""Applies all transformation function pointers to the input data.

Parameters
----------
recv_data : numpy.ndarray
data received on the port that shall be transformed

Returns
-------
recv_data : numpy.ndarray
received data, transformed by the incoming virtual ports
"""
if self._transform_funcs:
# apply all transformation functions to the received data
for f in reversed(self._transform_funcs):
recv_data = f(recv_data)
return recv_data


class PyRefPortVectorDense(PyRefPort):
"""Python implementation of RefPort for dense vector data."""
Expand All @@ -529,8 +551,10 @@ def read(self) -> np.ndarray:
header = np.ones(self._csp_send_port.shape) * VarPortCmd.GET
self._csp_send_port.send(header)

return self._csp_recv_port.recv()
return self._transform(self._csp_recv_port.recv())

# TODO (MR): self._shape must be set to the correct shape when
# instantiating the Port
return np.zeros(self._shape, self._d_type)

def write(self, data: np.ndarray):
Expand Down Expand Up @@ -660,7 +684,10 @@ def __init__(self,
csp_recv_port: ty.Optional[CspRecvPort],
process_model: AbstractProcessModel,
shape: ty.Tuple[int, ...] = tuple(),
d_type: type = int):
d_type: type = int,
transform_funcs: ty.Optional[ty.List[ft.partial]] = None):

self._transform_funcs = transform_funcs
self._csp_recv_port = csp_recv_port
self._csp_send_port = csp_send_port
self.var_name = var_name
Expand Down Expand Up @@ -692,6 +719,25 @@ def service(self):
"""
pass

def _transform(self, recv_data: np.array) -> np.array:
"""Applies all transformation function pointers to the input data.

Parameters
----------
recv_data : numpy.ndarray
data received on the port that shall be transformed

Returns
-------
recv_data : numpy.ndarray
received data, transformed by the incoming virtual ports
"""
if self._transform_funcs:
# apply all transformation functions to the received data
for f in self._transform_funcs:
recv_data = f(recv_data)
return recv_data


class PyVarPortVectorDense(PyVarPort):
"""Python implementation of VarPort for dense vector data."""
Expand All @@ -712,7 +758,7 @@ def service(self):

# Set the value of the Var with the given data
if enum_equal(cmd, VarPortCmd.SET):
data = self._csp_recv_port.recv()
data = self._transform(self._csp_recv_port.recv())
setattr(self._process_model, self.var_name, data)
elif enum_equal(cmd, VarPortCmd.GET):
data = getattr(self._process_model, self.var_name)
Expand Down
Loading