Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed automatic append of py_ports and var_ports #669

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/lava/magma/compiler/channels/pypychannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ def recv(self):
result = self._array[self._idx].copy()
self._idx = (self._idx + 1) % self._size
self._ack.release()

return result

def join(self):
Expand Down
5 changes: 2 additions & 3 deletions src/lava/magma/core/model/py/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ def __setattr__(self, key: str, value: ty.Any):

"""
self.__dict__[key] = value
if isinstance(value, AbstractPyPort):
if isinstance(value, AbstractPyPort) and value not in self.py_ports:
self.py_ports.append(value)
# Store all VarPorts for efficient RefPort -> VarPort handling
if isinstance(value, PyVarPort):
if isinstance(value, PyVarPort) and value not in self.var_ports:
self.var_ports.append(value)

def start(self):
Expand Down
19 changes: 12 additions & 7 deletions src/lava/magma/core/model/py/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,13 +699,16 @@ def read(self) -> np.ndarray:
The value of the referenced Var.
"""
if self._csp_send_port and self._csp_recv_port:
header = np.ones(self._csp_send_port.shape) * VarPortCmd.GET
self._csp_send_port.send(header)

if not hasattr(self, 'get_header'):
self.get_header = (np.ones(self._csp_send_port.shape)
* VarPortCmd.GET)
self._csp_send_port.send(self.get_header)
return self._transformer.transform(self._csp_recv_port.recv(),
self._csp_recv_port)

return np.zeros(self._shape, self._d_type)
else:
if not hasattr(self, 'get_zeros'):
self.get_zeros = np.zeros(self._shape, self._d_type)
return self.get_zeros

def write(self, data: np.ndarray):
"""Abstract method to write data to a VarPort to set the value of the
Expand All @@ -717,8 +720,10 @@ def write(self, data: np.ndarray):
The data to send via _csp_send_port.
"""
if self._csp_send_port:
header = np.ones(self._csp_send_port.shape) * VarPortCmd.SET
self._csp_send_port.send(header)
if not hasattr(self, 'set_header'):
self.set_header = (np.ones(self._csp_send_port.shape)
* VarPortCmd.SET)
self._csp_send_port.send(self.set_header)
self._csp_send_port.send(data)


Expand Down