-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Performance improvements #87
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,12 +3,12 @@ | |
# See: https://spdx.org/licenses/ | ||
import typing as ty | ||
from queue import Queue, Empty | ||
from threading import Thread | ||
from threading import BoundedSemaphore, Condition, Thread | ||
from time import time | ||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
from multiprocessing import Pipe, BoundedSemaphore | ||
from multiprocessing import Semaphore | ||
|
||
from lava.magma.compiler.channels.interfaces import ( | ||
Channel, | ||
|
@@ -42,8 +42,8 @@ def __init__(self, name, shm, proto, size, req, ack): | |
shm : SharedMemory | ||
proto : Proto | ||
size : int | ||
req : Pipe | ||
ack : Pipe | ||
req : Semaphore | ||
ack : Semaphore | ||
""" | ||
self._name = name | ||
self._shm = shm | ||
|
@@ -57,6 +57,7 @@ def __init__(self, name, shm, proto, size, req, ack): | |
self._done = False | ||
self._array = [] | ||
self._semaphore = None | ||
self.observer = None | ||
self.thread = None | ||
|
||
@property | ||
|
@@ -96,8 +97,11 @@ def start(self): | |
def _ack_callback(self): | ||
try: | ||
while not self._done: | ||
self._ack.recv_bytes(0) | ||
self._ack.acquire() | ||
not_full = self.probe() | ||
self._semaphore.release() | ||
if self.observer and not not_full: | ||
self.observer() | ||
except EOFError: | ||
pass | ||
|
||
|
@@ -120,7 +124,7 @@ def send(self, data): | |
self._semaphore.acquire() | ||
self._array[self._idx][:] = data[:] | ||
self._idx = (self._idx + 1) % self._size | ||
self._req.send_bytes(bytes(0)) | ||
self._req.release() | ||
|
||
def join(self): | ||
self._done = True | ||
|
@@ -175,8 +179,8 @@ def __init__(self, name, shm, proto, size, req, ack): | |
shm : SharedMemory | ||
proto : Proto | ||
size : int | ||
req : Pipe | ||
ack : Pipe | ||
req : Semaphore | ||
ack : Semaphore | ||
""" | ||
self._name = name | ||
self._shm = shm | ||
|
@@ -190,6 +194,7 @@ def __init__(self, name, shm, proto, size, req, ack): | |
self._done = False | ||
self._array = [] | ||
self._queue = None | ||
self.observer = None | ||
self.thread = None | ||
|
||
@property | ||
|
@@ -229,8 +234,11 @@ def start(self): | |
def _req_callback(self): | ||
try: | ||
while not self._done: | ||
self._req.recv_bytes(0) | ||
self._req.acquire() | ||
not_empty = self.probe() | ||
self._queue.put_nowait(0) | ||
if self.observer and not not_empty: | ||
self.observer() | ||
except EOFError: | ||
pass | ||
|
||
|
@@ -257,14 +265,51 @@ def recv(self): | |
self._queue.get() | ||
result = self._array[self._idx].copy() | ||
self._idx = (self._idx + 1) % self._size | ||
self._ack.send_bytes(bytes(0)) | ||
self._ack.release() | ||
|
||
return result | ||
|
||
def join(self): | ||
self._done = True | ||
|
||
|
||
class CspSelector: | ||
""" | ||
Utility class to allow waiting for multiple channels to become ready | ||
""" | ||
|
||
def __init__(self): | ||
"""Instantiates CspSelector object and class attributes""" | ||
self._cv = Condition() | ||
|
||
def _changed(self): | ||
with self._cv: | ||
self._cv.notify_all() | ||
|
||
def _set_observer(self, channel_actions, observer): | ||
for channel, _ in channel_actions: | ||
channel.observer = observer | ||
|
||
def select( | ||
self, | ||
*args: ty.Tuple[ty.Union[CspSendPort, CspRecvPort], | ||
ty.Callable[[], ty.Any] | ||
] | ||
): | ||
""" | ||
Wait for any channel to become ready, then execute the corresponding | ||
callable and return the result. | ||
""" | ||
with self._cv: | ||
self._set_observer(args, self._changed) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we have to set and unset this all the time? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In an earlier version, all channels notified, and _changed() determined if the source channel is part of the current set being selected on. (This is also the reason for the now unused channel argument, which I didn't bother to remove.) I thought doing it this way was cheaper and cleaner. |
||
while True: | ||
for channel, action in args: | ||
if channel.probe(): | ||
self._set_observer(args, None) | ||
return action() | ||
self._cv.wait() | ||
|
||
|
||
class PyPyChannel(Channel): | ||
"""Helper class to create the set of send and recv port and encapsulate | ||
them inside a common structure. We call this a PyPyChannel""" | ||
|
@@ -290,11 +335,11 @@ def __init__(self, | |
nbytes = np.prod(shape) * np.dtype(dtype).itemsize | ||
smm = message_infrastructure.smm | ||
shm = smm.SharedMemory(int(nbytes * size)) | ||
req = Pipe(duplex=False) | ||
ack = Pipe(duplex=False) | ||
req = Semaphore(0) | ||
ack = Semaphore(0) | ||
proto = Proto(shape=shape, dtype=dtype, nbytes=nbytes) | ||
self._src_port = CspSendPort(src_name, shm, proto, size, req[1], ack[0]) | ||
self._dst_port = CspRecvPort(dst_name, shm, proto, size, req[0], ack[1]) | ||
self._src_port = CspSendPort(src_name, shm, proto, size, req, ack) | ||
self._dst_port = CspRecvPort(dst_name, shm, proto, size, req, ack) | ||
|
||
@property | ||
def src_port(self) -> AbstractCspSendPort: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,13 @@ | |
|
||
import numpy as np | ||
|
||
from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort | ||
from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort,\ | ||
CspSelector | ||
from lava.magma.core.model.model import AbstractProcessModel | ||
from lava.magma.core.model.py.ports import AbstractPyPort, PyVarPort | ||
from lava.magma.runtime.mgmt_token_enums import ( | ||
enum_to_np, | ||
enum_equal, | ||
MGMT_COMMAND, | ||
MGMT_RESPONSE, REQ_TYPE, | ||
) | ||
|
@@ -113,71 +115,76 @@ def run(self): | |
(service_to_process_cmd). After calling the method of a phase of all | ||
ProcessModels the runtime service is informed about completion. The | ||
loop ends when the STOP command is received.""" | ||
selector = CspSelector() | ||
action = 'cmd' | ||
while True: | ||
# Probe if there is a new command from the runtime service | ||
if self.service_to_process_cmd.probe(): | ||
if action == 'cmd': | ||
phase = self.service_to_process_cmd.recv() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, so one key change is that instead of constantly calling probe(), we immediately block on the command channel until we receive anything. |
||
if np.array_equal(phase, MGMT_COMMAND.STOP): | ||
if enum_equal(phase, MGMT_COMMAND.STOP): | ||
self.process_to_service_ack.send(MGMT_RESPONSE.TERMINATED) | ||
self.join() | ||
return | ||
# Spiking phase - increase time step | ||
if np.array_equal(phase, PyLoihiProcessModel.Phase.SPK): | ||
if enum_equal(phase, PyLoihiProcessModel.Phase.SPK): | ||
self.current_ts += 1 | ||
self.run_spk() | ||
self.process_to_service_ack.send(MGMT_RESPONSE.DONE) | ||
# Pre-management phase | ||
elif np.array_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT): | ||
elif enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT): | ||
# Enable via guard method | ||
if self.pre_guard(): | ||
self.run_pre_mgmt() | ||
self.process_to_service_ack.send(MGMT_RESPONSE.DONE) | ||
# Handle VarPort requests from RefPorts | ||
if len(self.var_ports) > 0: | ||
self._handle_var_ports() | ||
# Learning phase | ||
elif np.array_equal(phase, PyLoihiProcessModel.Phase.LRN): | ||
elif enum_equal(phase, PyLoihiProcessModel.Phase.LRN): | ||
# Enable via guard method | ||
if self.lrn_guard(): | ||
self.run_lrn() | ||
self.process_to_service_ack.send(MGMT_RESPONSE.DONE) | ||
# Post-management phase | ||
elif np.array_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): | ||
elif enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): | ||
# Enable via guard method | ||
if self.post_guard(): | ||
self.run_post_mgmt() | ||
self.process_to_service_ack.send(MGMT_RESPONSE.DONE) | ||
# Handle VarPort requests from RefPorts | ||
if len(self.var_ports) > 0: | ||
self._handle_var_ports() | ||
# Host phase - called at the last time step before STOP | ||
elif np.array_equal(phase, PyLoihiProcessModel.Phase.HOST): | ||
# Handle get/set Var requests from runtime service | ||
self._handle_get_set_var() | ||
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST): | ||
pass | ||
else: | ||
raise ValueError(f"Wrong Phase Info Received : {phase}") | ||
elif action == 'req': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My idea was to treat everything as a command. May it be RUN, PAUSE, STOP, GET, SET, SPK, ... and just handle them all consistently. |
||
# Handle get/set Var requests from runtime service | ||
self._handle_get_set_var() | ||
else: | ||
# Handle VarPort requests from RefPorts | ||
self._handle_var_port(action) | ||
|
||
channel_actions = [(self.service_to_process_cmd, lambda: 'cmd')] | ||
if enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT) or \ | ||
enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): | ||
for var_port in self.var_ports: | ||
for csp_port in var_port.csp_ports: | ||
if isinstance(csp_port, CspRecvPort): | ||
channel_actions.append((csp_port, lambda: var_port)) | ||
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST): | ||
channel_actions.append((self.service_to_process_req, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From a type perspective it is a bit inconsistent to give either a string or a port back as an action. |
||
lambda: 'req')) | ||
action = selector.select(*channel_actions) | ||
|
||
# FIXME: (PP) might not be able to perform get/set during pause | ||
def _handle_get_set_var(self): | ||
"""Handles all get/set Var requests from the runtime service and calls | ||
the corresponding handling methods. The loop ends upon a | ||
new command from runtime service after all get/set Var requests have | ||
been handled.""" | ||
while True: | ||
# Probe if there is a get/set Var request from runtime service | ||
if self.service_to_process_req.probe(): | ||
# Get the type of the request | ||
request = self.service_to_process_req.recv() | ||
if np.array_equal(request, REQ_TYPE.GET): | ||
self._handle_get_var() | ||
elif np.array_equal(request, REQ_TYPE.SET): | ||
self._handle_set_var() | ||
else: | ||
raise RuntimeError(f"Unknown request type {request}") | ||
|
||
# End if another command from runtime service arrives | ||
if self.service_to_process_cmd.probe(): | ||
return | ||
# Get the type of the request | ||
request = self.service_to_process_req.recv() | ||
if enum_equal(request, REQ_TYPE.GET): | ||
self._handle_get_var() | ||
elif enum_equal(request, REQ_TYPE.SET): | ||
self._handle_set_var() | ||
else: | ||
raise RuntimeError(f"Unknown request type {request}") | ||
|
||
def _handle_get_var(self): | ||
"""Handles the get Var command from runtime service.""" | ||
|
@@ -232,16 +239,6 @@ def _handle_set_var(self): | |
else: | ||
raise RuntimeError("Unsupported type") | ||
|
||
# TODO: (PP) use select(..) to service VarPorts instead of a loop | ||
def _handle_var_ports(self): | ||
"""Handles read/write requests on any VarPorts. The loop ends upon a | ||
new command from runtime service after all VarPort service requests have | ||
been handled.""" | ||
while True: | ||
# Loop through read/write requests of each VarPort | ||
for vp in self.var_ports: | ||
vp.service() | ||
|
||
# End if another command from runtime service arrives | ||
if self.service_to_process_cmd.probe(): | ||
return | ||
def _handle_var_port(self, var_port): | ||
"""Handles read/write requests on the given VarPort.""" | ||
var_port.service() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This loop is still busy waiting, isn't it? Is this a performance concern? Could this also be improved in performance by using a condition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There no busy waiting here, because self._req.recv_bytes() would have blocked.