From 38e14d5b53c075923760faeff19781a3ae9ea4da Mon Sep 17 00:00:00 2001 From: harryliu-intel <87097499+harryliu-intel@users.noreply.github.com> Date: Mon, 22 Nov 2021 09:24:57 -0800 Subject: [PATCH] Performance improvements (#87) * Use specialized np.array_equal for performance. * Use select instead of probe with busy waiting. * Use multiprocessing semaphore instead of pipe. --- .../magma/compiler/channels/pypychannel.py | 73 +++++++++-- src/lava/magma/core/model/py/model.py | 87 ++++++------- src/lava/magma/core/model/py/ports.py | 6 +- src/lava/magma/runtime/mgmt_token_enums.py | 11 ++ src/lava/magma/runtime/runtime.py | 10 +- src/lava/magma/runtime/runtime_service.py | 121 +++++++++--------- 6 files changed, 180 insertions(+), 128 deletions(-) diff --git a/src/lava/magma/compiler/channels/pypychannel.py b/src/lava/magma/compiler/channels/pypychannel.py index f93287c2b..0450efd5a 100644 --- a/src/lava/magma/compiler/channels/pypychannel.py +++ b/src/lava/magma/compiler/channels/pypychannel.py @@ -3,12 +3,12 @@ # See: https://spdx.org/licenses/ import typing as ty from queue import Queue, Empty -from threading import Thread +from threading import BoundedSemaphore, Condition, Thread from time import time from dataclasses import dataclass import numpy as np -from multiprocessing import Pipe, BoundedSemaphore +from multiprocessing import Semaphore from lava.magma.compiler.channels.interfaces import ( Channel, @@ -42,8 +42,8 @@ def __init__(self, name, shm, proto, size, req, ack): shm : SharedMemory proto : Proto size : int - req : Pipe - ack : Pipe + req : Semaphore + ack : Semaphore """ self._name = name self._shm = shm @@ -57,6 +57,7 @@ def __init__(self, name, shm, proto, size, req, ack): self._done = False self._array = [] self._semaphore = None + self.observer = None self.thread = None @property @@ -96,8 +97,11 @@ def start(self): def _ack_callback(self): try: while not self._done: - self._ack.recv_bytes(0) + self._ack.acquire() + not_full = self.probe() self._semaphore.release() + if self.observer and not not_full: + self.observer() except EOFError: pass @@ -120,7 +124,7 @@ def send(self, data): self._semaphore.acquire() self._array[self._idx][:] = data[:] self._idx = (self._idx + 1) % self._size - self._req.send_bytes(bytes(0)) + self._req.release() def join(self): self._done = True @@ -175,8 +179,8 @@ def __init__(self, name, shm, proto, size, req, ack): shm : SharedMemory proto : Proto size : int - req : Pipe - ack : Pipe + req : Semaphore + ack : Semaphore """ self._name = name self._shm = shm @@ -190,6 +194,7 @@ def __init__(self, name, shm, proto, size, req, ack): self._done = False self._array = [] self._queue = None + self.observer = None self.thread = None @property @@ -229,8 +234,11 @@ def start(self): def _req_callback(self): try: while not self._done: - self._req.recv_bytes(0) + self._req.acquire() + not_empty = self.probe() self._queue.put_nowait(0) + if self.observer and not not_empty: + self.observer() except EOFError: pass @@ -257,7 +265,7 @@ def recv(self): self._queue.get() result = self._array[self._idx].copy() self._idx = (self._idx + 1) % self._size - self._ack.send_bytes(bytes(0)) + self._ack.release() return result @@ -265,6 +273,43 @@ def join(self): self._done = True +class CspSelector: + """ + Utility class to allow waiting for multiple channels to become ready + """ + + def __init__(self): + """Instantiates CspSelector object and class attributes""" + self._cv = Condition() + + def _changed(self): + with self._cv: + self._cv.notify_all() + + def 
_set_observer(self, channel_actions, observer): + for channel, _ in channel_actions: + channel.observer = observer + + def select( + self, + *args: ty.Tuple[ty.Union[CspSendPort, CspRecvPort], + ty.Callable[[], ty.Any] + ] + ): + """ + Wait for any channel to become ready, then execute the corresponding + callable and return the result. + """ + with self._cv: + self._set_observer(args, self._changed) + while True: + for channel, action in args: + if channel.probe(): + self._set_observer(args, None) + return action() + self._cv.wait() + + class PyPyChannel(Channel): """Helper class to create the set of send and recv port and encapsulate them inside a common structure. We call this a PyPyChannel""" @@ -290,11 +335,11 @@ def __init__(self, nbytes = np.prod(shape) * np.dtype(dtype).itemsize smm = message_infrastructure.smm shm = smm.SharedMemory(int(nbytes * size)) - req = Pipe(duplex=False) - ack = Pipe(duplex=False) + req = Semaphore(0) + ack = Semaphore(0) proto = Proto(shape=shape, dtype=dtype, nbytes=nbytes) - self._src_port = CspSendPort(src_name, shm, proto, size, req[1], ack[0]) - self._dst_port = CspRecvPort(dst_name, shm, proto, size, req[0], ack[1]) + self._src_port = CspSendPort(src_name, shm, proto, size, req, ack) + self._dst_port = CspRecvPort(dst_name, shm, proto, size, req, ack) @property def src_port(self) -> AbstractCspSendPort: diff --git a/src/lava/magma/core/model/py/model.py b/src/lava/magma/core/model/py/model.py index 49d58ff1f..fab00722c 100644 --- a/src/lava/magma/core/model/py/model.py +++ b/src/lava/magma/core/model/py/model.py @@ -6,11 +6,13 @@ import numpy as np -from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort +from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort,\ + CspSelector from lava.magma.core.model.model import AbstractProcessModel from lava.magma.core.model.py.ports import AbstractPyPort, PyVarPort from lava.magma.runtime.mgmt_token_enums import ( enum_to_np, + enum_equal, MGMT_COMMAND, MGMT_RESPONSE, REQ_TYPE, ) @@ -113,49 +115,61 @@ def run(self): (service_to_process_cmd). After calling the method of a phase of all ProcessModels the runtime service is informed about completion. 
The loop ends when the STOP command is received.""" + selector = CspSelector() + action = 'cmd' while True: - # Probe if there is a new command from the runtime service - if self.service_to_process_cmd.probe(): + if action == 'cmd': phase = self.service_to_process_cmd.recv() - if np.array_equal(phase, MGMT_COMMAND.STOP): + if enum_equal(phase, MGMT_COMMAND.STOP): self.process_to_service_ack.send(MGMT_RESPONSE.TERMINATED) self.join() return # Spiking phase - increase time step - if np.array_equal(phase, PyLoihiProcessModel.Phase.SPK): + if enum_equal(phase, PyLoihiProcessModel.Phase.SPK): self.current_ts += 1 self.run_spk() self.process_to_service_ack.send(MGMT_RESPONSE.DONE) # Pre-management phase - elif np.array_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT): + elif enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT): # Enable via guard method if self.pre_guard(): self.run_pre_mgmt() self.process_to_service_ack.send(MGMT_RESPONSE.DONE) - # Handle VarPort requests from RefPorts - if len(self.var_ports) > 0: - self._handle_var_ports() # Learning phase - elif np.array_equal(phase, PyLoihiProcessModel.Phase.LRN): + elif enum_equal(phase, PyLoihiProcessModel.Phase.LRN): # Enable via guard method if self.lrn_guard(): self.run_lrn() self.process_to_service_ack.send(MGMT_RESPONSE.DONE) # Post-management phase - elif np.array_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): + elif enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): # Enable via guard method if self.post_guard(): self.run_post_mgmt() self.process_to_service_ack.send(MGMT_RESPONSE.DONE) - # Handle VarPort requests from RefPorts - if len(self.var_ports) > 0: - self._handle_var_ports() # Host phase - called at the last time step before STOP - elif np.array_equal(phase, PyLoihiProcessModel.Phase.HOST): - # Handle get/set Var requests from runtime service - self._handle_get_set_var() + elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST): + pass else: raise ValueError(f"Wrong Phase Info Received : {phase}") + elif action == 'req': + # Handle get/set Var requests from runtime service + self._handle_get_set_var() + else: + # Handle VarPort requests from RefPorts + self._handle_var_port(action) + + channel_actions = [(self.service_to_process_cmd, lambda: 'cmd')] + if enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT) or \ + enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): + for var_port in self.var_ports: + for csp_port in var_port.csp_ports: + if isinstance(csp_port, CspRecvPort): + channel_actions.append((csp_port, lambda: var_port)) + elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST): + channel_actions.append((self.service_to_process_req, + lambda: 'req')) + action = selector.select(*channel_actions) # FIXME: (PP) might not be able to perform get/set during pause def _handle_get_set_var(self): @@ -163,21 +177,14 @@ def _handle_get_set_var(self): the corresponding handling methods. 
The loop ends upon a new command from runtime service after all get/set Var requests have been handled.""" - while True: - # Probe if there is a get/set Var request from runtime service - if self.service_to_process_req.probe(): - # Get the type of the request - request = self.service_to_process_req.recv() - if np.array_equal(request, REQ_TYPE.GET): - self._handle_get_var() - elif np.array_equal(request, REQ_TYPE.SET): - self._handle_set_var() - else: - raise RuntimeError(f"Unknown request type {request}") - - # End if another command from runtime service arrives - if self.service_to_process_cmd.probe(): - return + # Get the type of the request + request = self.service_to_process_req.recv() + if enum_equal(request, REQ_TYPE.GET): + self._handle_get_var() + elif enum_equal(request, REQ_TYPE.SET): + self._handle_set_var() + else: + raise RuntimeError(f"Unknown request type {request}") def _handle_get_var(self): """Handles the get Var command from runtime service.""" @@ -232,16 +239,6 @@ def _handle_set_var(self): else: raise RuntimeError("Unsupported type") - # TODO: (PP) use select(..) to service VarPorts instead of a loop - def _handle_var_ports(self): - """Handles read/write requests on any VarPorts. The loop ends upon a - new command from runtime service after all VarPort service requests have - been handled.""" - while True: - # Loop through read/write requests of each VarPort - for vp in self.var_ports: - vp.service() - - # End if another command from runtime service arrives - if self.service_to_process_cmd.probe(): - return + def _handle_var_port(self, var_port): + """Handles read/write requests on the given VarPort.""" + var_port.service() diff --git a/src/lava/magma/core/model/py/ports.py b/src/lava/magma/core/model/py/ports.py index 992d6a73e..6031c2c8d 100644 --- a/src/lava/magma/core/model/py/ports.py +++ b/src/lava/magma/core/model/py/ports.py @@ -9,7 +9,7 @@ from lava.magma.compiler.channels.interfaces import AbstractCspPort from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort from lava.magma.core.model.interfaces import AbstractPortImplementation -from lava.magma.runtime.mgmt_token_enums import enum_to_np +from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal class AbstractPyPort(AbstractPortImplementation): @@ -301,10 +301,10 @@ def service(self): cmd = enum_to_np(self._csp_recv_port.recv()[0]) # Set the value of the Var with the given data - if np.array_equal(cmd, VarPortCmd.SET): + if enum_equal(cmd, VarPortCmd.SET): data = self._csp_recv_port.recv() setattr(self._process_model, self.var_name, data) - elif np.array_equal(cmd, VarPortCmd.GET): + elif enum_equal(cmd, VarPortCmd.GET): data = getattr(self._process_model, self.var_name) self._csp_send_port.send(data) else: diff --git a/src/lava/magma/runtime/mgmt_token_enums.py b/src/lava/magma/runtime/mgmt_token_enums.py index d3209f9c8..e56734635 100644 --- a/src/lava/magma/runtime/mgmt_token_enums.py +++ b/src/lava/magma/runtime/mgmt_token_enums.py @@ -20,6 +20,17 @@ def enum_to_np(value: ty.Union[int, float], return np.array([value], dtype=d_type) +def enum_equal(a: np.array, b: np.array) -> bool: + """ + Helper function to compare two np arrays created by enum_to_np. + + :param a: 1-D array created by enum_to_np + :param b: 1-D array created by enum_to_np + :return: True if the two arrays are equal + """ + return a[0] == b[0] + + class MGMT_COMMAND: """ Signifies the Mgmt Command being sent between two actors. 
These may be diff --git a/src/lava/magma/runtime/runtime.py b/src/lava/magma/runtime/runtime.py index e7f4c6b24..194096b6a 100644 --- a/src/lava/magma/runtime/runtime.py +++ b/src/lava/magma/runtime/runtime.py @@ -14,8 +14,8 @@ import MessageInfrastructureInterface from lava.magma.runtime.message_infrastructure.factory import \ MessageInfrastructureFactory -from lava.magma.runtime.mgmt_token_enums import MGMT_COMMAND, MGMT_RESPONSE, \ - enum_to_np, REQ_TYPE +from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal, \ + MGMT_COMMAND, MGMT_RESPONSE, REQ_TYPE from lava.magma.runtime.runtime_service import AsyncPyRuntimeService if ty.TYPE_CHECKING: @@ -227,7 +227,7 @@ def _run(self, run_condition): if run_condition.blocking: for recv_port in self.service_to_runtime_ack: data = recv_port.recv() - if not np.array_equal(data, MGMT_RESPONSE.DONE): + if not enum_equal(data, MGMT_RESPONSE.DONE): raise RuntimeError(f"Runtime Received {data}") if run_condition.blocking: self.current_ts += self.num_steps @@ -244,7 +244,7 @@ def wait(self): if self._is_running: for recv_port in self.service_to_runtime_ack: data = recv_port.recv() - if not np.array_equal(data, MGMT_RESPONSE.DONE): + if not enum_equal(data, MGMT_RESPONSE.DONE): raise RuntimeError(f"Runtime Received {data}") self.current_ts += self.num_steps self._is_running = False @@ -260,7 +260,7 @@ def stop(self): send_port.send(MGMT_COMMAND.STOP) for recv_port in self.service_to_runtime_ack: data = recv_port.recv() - if not np.array_equal(data, MGMT_RESPONSE.TERMINATED): + if not enum_equal(data, MGMT_RESPONSE.TERMINATED): raise RuntimeError(f"Runtime Received {data}") self.join() self._is_running = False diff --git a/src/lava/magma/runtime/runtime_service.py b/src/lava/magma/runtime/runtime_service.py index 9e9053a1d..dd7d2f924 100644 --- a/src/lava/magma/runtime/runtime_service.py +++ b/src/lava/magma/runtime/runtime_service.py @@ -6,12 +6,15 @@ import numpy as np -from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort +from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort,\ + CspSelector from lava.magma.core.sync.protocol import AbstractSyncProtocol from lava.magma.runtime.mgmt_token_enums import ( + enum_to_np, + enum_equal, MGMT_RESPONSE, MGMT_COMMAND, - enum_to_np, REQ_TYPE, + REQ_TYPE, ) @@ -130,9 +133,8 @@ def _get_pm_resp(self) -> ty.Iterable[MGMT_RESPONSE]: counter = 0 while counter < num_responses_expected: ptos_recv_port = self.process_to_service_ack[counter] - if ptos_recv_port.probe(): - rcv_msgs.append(ptos_recv_port.recv()) - counter += 1 + rcv_msgs.append(ptos_recv_port.recv()) + counter += 1 return rcv_msgs def _relay_to_runtime_data_given_model_id(self, model_id: int): @@ -178,28 +180,31 @@ def run(self): In this case iterate through the phases of the Loihi protocol until the last time step is reached. The runtime is informed after the last time step. 
The loop ends when receiving the STOP command from the runtime.""" + selector = CspSelector() phase = LoihiPyRuntimeService.Phase.HOST while True: # Probe if there is a new command from the runtime - if self.runtime_to_service_cmd.probe(): + cmd = selector.select((self.runtime_to_service_cmd, lambda: True), + (self.runtime_to_service_req, lambda: False)) + if cmd: command = self.runtime_to_service_cmd.recv() - if np.array_equal(command, MGMT_COMMAND.STOP): + if enum_equal(command, MGMT_COMMAND.STOP): # Inform all ProcessModels about the STOP command self._send_pm_cmd(command) rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.TERMINATED): + if not enum_equal(rsp, MGMT_RESPONSE.TERMINATED): raise ValueError(f"Wrong Response Received : {rsp}") # Inform the runtime about successful termination self.service_to_runtime_ack.send(MGMT_RESPONSE.TERMINATED) self.join() return - elif np.array_equal(command, MGMT_COMMAND.PAUSE): + elif enum_equal(command, MGMT_COMMAND.PAUSE): # Inform all ProcessModels about the PAUSE command self._send_pm_cmd(command) rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.PAUSED): + if not enum_equal(rsp, MGMT_RESPONSE.PAUSED): raise ValueError(f"Wrong Response Received : {rsp}") # Inform the runtime about successful pausing self.service_to_runtime_ack.send(MGMT_RESPONSE.PAUSED) @@ -211,75 +216,69 @@ def run(self): phase = LoihiPyRuntimeService.Phase.HOST while True: # Check if it is the last time step - is_last_ts = np.array_equal(enum_to_np(curr_time_step), - command) + is_last_ts = enum_equal(enum_to_np(curr_time_step), + command) # Advance to the next phase phase = self._next_phase(phase, is_last_ts) # Increase time step if spiking phase - if np.array_equal(phase, - LoihiPyRuntimeService.Phase.SPK): + if enum_equal(phase, LoihiPyRuntimeService.Phase.SPK): curr_time_step += 1 # Inform ProcessModels about current phase self._send_pm_cmd(phase) # ProcessModels respond with DONE if not HOST phase - if not np.array_equal( + if not enum_equal( phase, LoihiPyRuntimeService.Phase.HOST): rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.DONE): + if not enum_equal(rsp, MGMT_RESPONSE.DONE): raise ValueError( f"Wrong Response Received : {rsp}") # If HOST phase (last time step ended) break the loop - if np.array_equal( + if enum_equal( phase, LoihiPyRuntimeService.Phase.HOST): break # Inform the runtime that last time step was reached self.service_to_runtime_ack.send(MGMT_RESPONSE.DONE) - - # Handle get/set Var - self._handle_get_set(phase) + else: + # Handle get/set Var + self._handle_get_set(phase) def _handle_get_set(self, phase): - if np.array_equal(phase, LoihiPyRuntimeService.Phase.HOST): - while True: - if self.runtime_to_service_req.probe(): - request = self.runtime_to_service_req.recv() - if np.array_equal(request, REQ_TYPE.GET): - requests: ty.List[np.ndarray] = [request] - # recv model_id - model_id: int = \ - self.runtime_to_service_req.recv()[ - 0].item() - # recv var_id - requests.append( - self.runtime_to_service_req.recv()) - self._send_pm_req_given_model_id(model_id, - *requests) - - self._relay_to_runtime_data_given_model_id( - model_id) - elif np.array_equal(request, REQ_TYPE.SET): - requests: ty.List[np.ndarray] = [request] - # recv model_id - model_id: int = \ - self.runtime_to_service_req.recv()[ - 0].item() - # recv var_id - requests.append( - self.runtime_to_service_req.recv()) - self._send_pm_req_given_model_id(model_id, - *requests) - - 
self._relay_to_pm_data_given_model_id( - model_id) - else: - raise RuntimeError( - f"Unknown request {request}") - - if self.runtime_to_service_cmd.probe(): - return + if enum_equal(phase, LoihiPyRuntimeService.Phase.HOST): + request = self.runtime_to_service_req.recv() + if enum_equal(request, REQ_TYPE.GET): + requests: ty.List[np.ndarray] = [request] + # recv model_id + model_id: int = \ + self.runtime_to_service_req.recv()[ + 0].item() + # recv var_id + requests.append( + self.runtime_to_service_req.recv()) + self._send_pm_req_given_model_id(model_id, + *requests) + + self._relay_to_runtime_data_given_model_id( + model_id) + elif enum_equal(request, REQ_TYPE.SET): + requests: ty.List[np.ndarray] = [request] + # recv model_id + model_id: int = \ + self.runtime_to_service_req.recv()[ + 0].item() + # recv var_id + requests.append( + self.runtime_to_service_req.recv()) + self._send_pm_req_given_model_id(model_id, + *requests) + + self._relay_to_pm_data_given_model_id( + model_id) + else: + raise RuntimeError( + f"Unknown request {request}") class LoihiCRuntimeService(AbstractRuntimeService): @@ -304,11 +303,11 @@ def _get_pm_resp(self) -> ty.Iterable[MGMT_RESPONSE]: def run(self): while True: command = self.runtime_to_service_cmd.recv() - if np.array_equal(command, MGMT_COMMAND.STOP): + if enum_equal(command, MGMT_COMMAND.STOP): self._send_pm_cmd(command) rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.TERMINATED): + if not enum_equal(rsp, MGMT_RESPONSE.TERMINATED): raise ValueError(f"Wrong Response Received : {rsp}") self.service_to_runtime_ack.send(MGMT_RESPONSE.TERMINATED) self.join() @@ -317,6 +316,6 @@ def run(self): self._send_pm_cmd(MGMT_COMMAND.RUN) rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.DONE): + if not enum_equal(rsp, MGMT_RESPONSE.DONE): raise ValueError(f"Wrong Response Received : {rsp}") self.service_to_runtime_ack.send(MGMT_RESPONSE.DONE)
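
The select()/observer mechanism introduced by this patch is easier to see outside the diff. Below is a minimal, self-contained sketch of the same pattern; ToyChannel and ToySelector are hypothetical stand-ins for CspRecvPort/CspSendPort and CspSelector (illustrative names, not part of this patch), reduced to the probe()/observer interface that select() relies on. Each channel invokes an observer callback when it transitions from not-ready to ready, and the selector sleeps on a threading.Condition until any registered channel reports readiness, instead of busy-waiting on probe().

# Illustrative sketch only: ToyChannel and ToySelector are hypothetical
# stand-ins for the patched CspRecvPort/CspSendPort and CspSelector.
import queue
import threading
import time
import typing as ty


class ToyChannel:
    """Channel exposing the probe()/observer interface select() expects."""

    def __init__(self, name: str):
        self.name = name
        self._queue: "queue.Queue" = queue.Queue()
        # select() installs a callback here; the channel fires it when it
        # becomes ready (empty -> non-empty transition), mirroring the
        # _req_callback/_ack_callback threads in pypychannel.py.
        self.observer: ty.Optional[ty.Callable[[], None]] = None

    def probe(self) -> bool:
        return not self._queue.empty()

    def send(self, item: ty.Any) -> None:
        was_ready = self.probe()
        self._queue.put(item)
        if self.observer and not was_ready:
            self.observer()

    def recv(self) -> ty.Any:
        return self._queue.get()


class ToySelector:
    """Same shape as CspSelector.select(): block until any channel is ready,
    then run and return the action paired with that channel."""

    def __init__(self):
        self._cv = threading.Condition()

    def _changed(self) -> None:
        with self._cv:
            self._cv.notify_all()

    def select(self, *channel_actions: ty.Tuple["ToyChannel",
                                                ty.Callable[[], ty.Any]]):
        with self._cv:
            for channel, _ in channel_actions:
                channel.observer = self._changed
            while True:
                for channel, action in channel_actions:
                    if channel.probe():
                        for other, _ in channel_actions:
                            other.observer = None
                        return action()
                # Sleep until some channel's observer fires; this replaces
                # the previous pattern of spinning on channel.probe().
                self._cv.wait()


if __name__ == "__main__":
    cmd_ch = ToyChannel("cmd")
    req_ch = ToyChannel("req")

    def producer() -> None:
        time.sleep(0.1)
        req_ch.send("GET")

    threading.Thread(target=producer, daemon=True).start()
    selector = ToySelector()
    ready = selector.select((cmd_ch, lambda: "cmd"),
                            (req_ch, lambda: "req"))
    print(ready)  # prints "req" once the producer thread sends

As in the patched _req_callback/_ack_callback, the observer is only invoked on the not-ready-to-ready transition, so steady traffic on an already-ready channel never touches the Condition; the waiting side still re-checks probe() on every wakeup before acting.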