From 66eae9b1289d2d682bbe53f439b6ffc7d3b5f409 Mon Sep 17 00:00:00 2001 From: "Liu, Ruokun" Date: Thu, 18 Nov 2021 17:11:19 -0800 Subject: [PATCH 1/3] Use specialized np.array_equal for performance. --- src/lava/magma/core/model/py/model.py | 17 +++++----- src/lava/magma/core/model/py/ports.py | 6 ++-- src/lava/magma/runtime/mgmt_token_enums.py | 11 +++++++ src/lava/magma/runtime/runtime.py | 10 +++--- src/lava/magma/runtime/runtime_service.py | 37 +++++++++++----------- 5 files changed, 47 insertions(+), 34 deletions(-) diff --git a/src/lava/magma/core/model/py/model.py b/src/lava/magma/core/model/py/model.py index 49d58ff1f..d6d40a7d9 100644 --- a/src/lava/magma/core/model/py/model.py +++ b/src/lava/magma/core/model/py/model.py @@ -11,6 +11,7 @@ from lava.magma.core.model.py.ports import AbstractPyPort, PyVarPort from lava.magma.runtime.mgmt_token_enums import ( enum_to_np, + enum_equal, MGMT_COMMAND, MGMT_RESPONSE, REQ_TYPE, ) @@ -117,17 +118,17 @@ def run(self): # Probe if there is a new command from the runtime service if self.service_to_process_cmd.probe(): phase = self.service_to_process_cmd.recv() - if np.array_equal(phase, MGMT_COMMAND.STOP): + if enum_equal(phase, MGMT_COMMAND.STOP): self.process_to_service_ack.send(MGMT_RESPONSE.TERMINATED) self.join() return # Spiking phase - increase time step - if np.array_equal(phase, PyLoihiProcessModel.Phase.SPK): + if enum_equal(phase, PyLoihiProcessModel.Phase.SPK): self.current_ts += 1 self.run_spk() self.process_to_service_ack.send(MGMT_RESPONSE.DONE) # Pre-management phase - elif np.array_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT): + elif enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT): # Enable via guard method if self.pre_guard(): self.run_pre_mgmt() @@ -136,13 +137,13 @@ def run(self): if len(self.var_ports) > 0: self._handle_var_ports() # Learning phase - elif np.array_equal(phase, PyLoihiProcessModel.Phase.LRN): + elif enum_equal(phase, PyLoihiProcessModel.Phase.LRN): # Enable via guard method if self.lrn_guard(): self.run_lrn() self.process_to_service_ack.send(MGMT_RESPONSE.DONE) # Post-management phase - elif np.array_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): + elif enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): # Enable via guard method if self.post_guard(): self.run_post_mgmt() @@ -151,7 +152,7 @@ def run(self): if len(self.var_ports) > 0: self._handle_var_ports() # Host phase - called at the last time step before STOP - elif np.array_equal(phase, PyLoihiProcessModel.Phase.HOST): + elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST): # Handle get/set Var requests from runtime service self._handle_get_set_var() else: @@ -168,9 +169,9 @@ def _handle_get_set_var(self): if self.service_to_process_req.probe(): # Get the type of the request request = self.service_to_process_req.recv() - if np.array_equal(request, REQ_TYPE.GET): + if enum_equal(request, REQ_TYPE.GET): self._handle_get_var() - elif np.array_equal(request, REQ_TYPE.SET): + elif enum_equal(request, REQ_TYPE.SET): self._handle_set_var() else: raise RuntimeError(f"Unknown request type {request}") diff --git a/src/lava/magma/core/model/py/ports.py b/src/lava/magma/core/model/py/ports.py index 992d6a73e..6031c2c8d 100644 --- a/src/lava/magma/core/model/py/ports.py +++ b/src/lava/magma/core/model/py/ports.py @@ -9,7 +9,7 @@ from lava.magma.compiler.channels.interfaces import AbstractCspPort from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort from lava.magma.core.model.interfaces import AbstractPortImplementation -from lava.magma.runtime.mgmt_token_enums import enum_to_np +from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal class AbstractPyPort(AbstractPortImplementation): @@ -301,10 +301,10 @@ def service(self): cmd = enum_to_np(self._csp_recv_port.recv()[0]) # Set the value of the Var with the given data - if np.array_equal(cmd, VarPortCmd.SET): + if enum_equal(cmd, VarPortCmd.SET): data = self._csp_recv_port.recv() setattr(self._process_model, self.var_name, data) - elif np.array_equal(cmd, VarPortCmd.GET): + elif enum_equal(cmd, VarPortCmd.GET): data = getattr(self._process_model, self.var_name) self._csp_send_port.send(data) else: diff --git a/src/lava/magma/runtime/mgmt_token_enums.py b/src/lava/magma/runtime/mgmt_token_enums.py index d3209f9c8..e56734635 100644 --- a/src/lava/magma/runtime/mgmt_token_enums.py +++ b/src/lava/magma/runtime/mgmt_token_enums.py @@ -20,6 +20,17 @@ def enum_to_np(value: ty.Union[int, float], return np.array([value], dtype=d_type) +def enum_equal(a: np.array, b: np.array) -> bool: + """ + Helper function to compare two np arrays created by enum_to_np. + + :param a: 1-D array created by enum_to_np + :param b: 1-D array created by enum_to_np + :return: True if the two arrays are equal + """ + return a[0] == b[0] + + class MGMT_COMMAND: """ Signifies the Mgmt Command being sent between two actors. These may be diff --git a/src/lava/magma/runtime/runtime.py b/src/lava/magma/runtime/runtime.py index e7f4c6b24..194096b6a 100644 --- a/src/lava/magma/runtime/runtime.py +++ b/src/lava/magma/runtime/runtime.py @@ -14,8 +14,8 @@ import MessageInfrastructureInterface from lava.magma.runtime.message_infrastructure.factory import \ MessageInfrastructureFactory -from lava.magma.runtime.mgmt_token_enums import MGMT_COMMAND, MGMT_RESPONSE, \ - enum_to_np, REQ_TYPE +from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal, \ + MGMT_COMMAND, MGMT_RESPONSE, REQ_TYPE from lava.magma.runtime.runtime_service import AsyncPyRuntimeService if ty.TYPE_CHECKING: @@ -227,7 +227,7 @@ def _run(self, run_condition): if run_condition.blocking: for recv_port in self.service_to_runtime_ack: data = recv_port.recv() - if not np.array_equal(data, MGMT_RESPONSE.DONE): + if not enum_equal(data, MGMT_RESPONSE.DONE): raise RuntimeError(f"Runtime Received {data}") if run_condition.blocking: self.current_ts += self.num_steps @@ -244,7 +244,7 @@ def wait(self): if self._is_running: for recv_port in self.service_to_runtime_ack: data = recv_port.recv() - if not np.array_equal(data, MGMT_RESPONSE.DONE): + if not enum_equal(data, MGMT_RESPONSE.DONE): raise RuntimeError(f"Runtime Received {data}") self.current_ts += self.num_steps self._is_running = False @@ -260,7 +260,7 @@ def stop(self): send_port.send(MGMT_COMMAND.STOP) for recv_port in self.service_to_runtime_ack: data = recv_port.recv() - if not np.array_equal(data, MGMT_RESPONSE.TERMINATED): + if not enum_equal(data, MGMT_RESPONSE.TERMINATED): raise RuntimeError(f"Runtime Received {data}") self.join() self._is_running = False diff --git a/src/lava/magma/runtime/runtime_service.py b/src/lava/magma/runtime/runtime_service.py index 9e9053a1d..e2e3e1065 100644 --- a/src/lava/magma/runtime/runtime_service.py +++ b/src/lava/magma/runtime/runtime_service.py @@ -9,9 +9,11 @@ from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort from lava.magma.core.sync.protocol import AbstractSyncProtocol from lava.magma.runtime.mgmt_token_enums import ( + enum_to_np, + enum_equal, MGMT_RESPONSE, MGMT_COMMAND, - enum_to_np, REQ_TYPE, + REQ_TYPE, ) @@ -183,23 +185,23 @@ def run(self): # Probe if there is a new command from the runtime if self.runtime_to_service_cmd.probe(): command = self.runtime_to_service_cmd.recv() - if np.array_equal(command, MGMT_COMMAND.STOP): + if enum_equal(command, MGMT_COMMAND.STOP): # Inform all ProcessModels about the STOP command self._send_pm_cmd(command) rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.TERMINATED): + if not enum_equal(rsp, MGMT_RESPONSE.TERMINATED): raise ValueError(f"Wrong Response Received : {rsp}") # Inform the runtime about successful termination self.service_to_runtime_ack.send(MGMT_RESPONSE.TERMINATED) self.join() return - elif np.array_equal(command, MGMT_COMMAND.PAUSE): + elif enum_equal(command, MGMT_COMMAND.PAUSE): # Inform all ProcessModels about the PAUSE command self._send_pm_cmd(command) rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.PAUSED): + if not enum_equal(rsp, MGMT_RESPONSE.PAUSED): raise ValueError(f"Wrong Response Received : {rsp}") # Inform the runtime about successful pausing self.service_to_runtime_ack.send(MGMT_RESPONSE.PAUSED) @@ -211,27 +213,26 @@ def run(self): phase = LoihiPyRuntimeService.Phase.HOST while True: # Check if it is the last time step - is_last_ts = np.array_equal(enum_to_np(curr_time_step), - command) + is_last_ts = enum_equal(enum_to_np(curr_time_step), + command) # Advance to the next phase phase = self._next_phase(phase, is_last_ts) # Increase time step if spiking phase - if np.array_equal(phase, - LoihiPyRuntimeService.Phase.SPK): + if enum_equal(phase, LoihiPyRuntimeService.Phase.SPK): curr_time_step += 1 # Inform ProcessModels about current phase self._send_pm_cmd(phase) # ProcessModels respond with DONE if not HOST phase - if not np.array_equal( + if not enum_equal( phase, LoihiPyRuntimeService.Phase.HOST): rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.DONE): + if not enum_equal(rsp, MGMT_RESPONSE.DONE): raise ValueError( f"Wrong Response Received : {rsp}") # If HOST phase (last time step ended) break the loop - if np.array_equal( + if enum_equal( phase, LoihiPyRuntimeService.Phase.HOST): break @@ -242,11 +243,11 @@ def run(self): self._handle_get_set(phase) def _handle_get_set(self, phase): - if np.array_equal(phase, LoihiPyRuntimeService.Phase.HOST): + if enum_equal(phase, LoihiPyRuntimeService.Phase.HOST): while True: if self.runtime_to_service_req.probe(): request = self.runtime_to_service_req.recv() - if np.array_equal(request, REQ_TYPE.GET): + if enum_equal(request, REQ_TYPE.GET): requests: ty.List[np.ndarray] = [request] # recv model_id model_id: int = \ @@ -260,7 +261,7 @@ def _handle_get_set(self, phase): self._relay_to_runtime_data_given_model_id( model_id) - elif np.array_equal(request, REQ_TYPE.SET): + elif enum_equal(request, REQ_TYPE.SET): requests: ty.List[np.ndarray] = [request] # recv model_id model_id: int = \ @@ -304,11 +305,11 @@ def _get_pm_resp(self) -> ty.Iterable[MGMT_RESPONSE]: def run(self): while True: command = self.runtime_to_service_cmd.recv() - if np.array_equal(command, MGMT_COMMAND.STOP): + if enum_equal(command, MGMT_COMMAND.STOP): self._send_pm_cmd(command) rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.TERMINATED): + if not enum_equal(rsp, MGMT_RESPONSE.TERMINATED): raise ValueError(f"Wrong Response Received : {rsp}") self.service_to_runtime_ack.send(MGMT_RESPONSE.TERMINATED) self.join() @@ -317,6 +318,6 @@ def run(self): self._send_pm_cmd(MGMT_COMMAND.RUN) rsps = self._get_pm_resp() for rsp in rsps: - if not np.array_equal(rsp, MGMT_RESPONSE.DONE): + if not enum_equal(rsp, MGMT_RESPONSE.DONE): raise ValueError(f"Wrong Response Received : {rsp}") self.service_to_runtime_ack.send(MGMT_RESPONSE.DONE) From 8471f3fe7c034b9b7ad6270be0ba9b1c4cbfb309 Mon Sep 17 00:00:00 2001 From: "Liu, Ruokun" Date: Sat, 20 Nov 2021 19:53:56 -0800 Subject: [PATCH 2/3] Use select instead of probe with busy waiting. --- .../magma/compiler/channels/pypychannel.py | 49 ++++++++++- src/lava/magma/core/model/py/model.py | 74 ++++++++-------- src/lava/magma/runtime/runtime_service.py | 88 +++++++++---------- 3 files changed, 125 insertions(+), 86 deletions(-) diff --git a/src/lava/magma/compiler/channels/pypychannel.py b/src/lava/magma/compiler/channels/pypychannel.py index f93287c2b..ca768b628 100644 --- a/src/lava/magma/compiler/channels/pypychannel.py +++ b/src/lava/magma/compiler/channels/pypychannel.py @@ -3,12 +3,12 @@ # See: https://spdx.org/licenses/ import typing as ty from queue import Queue, Empty -from threading import Thread +from threading import BoundedSemaphore, Condition, Thread from time import time from dataclasses import dataclass import numpy as np -from multiprocessing import Pipe, BoundedSemaphore +from multiprocessing import Pipe from lava.magma.compiler.channels.interfaces import ( Channel, @@ -57,6 +57,7 @@ def __init__(self, name, shm, proto, size, req, ack): self._done = False self._array = [] self._semaphore = None + self.observer = None self.thread = None @property @@ -97,7 +98,10 @@ def _ack_callback(self): try: while not self._done: self._ack.recv_bytes(0) + not_full = self.probe() self._semaphore.release() + if self.observer and not not_full: + self.observer() except EOFError: pass @@ -190,6 +194,7 @@ def __init__(self, name, shm, proto, size, req, ack): self._done = False self._array = [] self._queue = None + self.observer = None self.thread = None @property @@ -230,7 +235,10 @@ def _req_callback(self): try: while not self._done: self._req.recv_bytes(0) + not_empty = self.probe() self._queue.put_nowait(0) + if self.observer and not not_empty: + self.observer() except EOFError: pass @@ -265,6 +273,43 @@ def join(self): self._done = True +class CspSelector: + """ + Utility class to allow waiting for multiple channels to become ready + """ + + def __init__(self): + """Instantiates CspSelector object and class attributes""" + self._cv = Condition() + + def _changed(self): + with self._cv: + self._cv.notify_all() + + def _set_observer(self, channel_actions, observer): + for channel, _ in channel_actions: + channel.observer = observer + + def select( + self, + *args: ty.Tuple[ty.Union[CspSendPort, CspRecvPort], + ty.Callable[[], ty.Any] + ] + ): + """ + Wait for any channel to become ready, then execute the corresponding + callable and return the result. + """ + with self._cv: + self._set_observer(args, self._changed) + while True: + for channel, action in args: + if channel.probe(): + self._set_observer(args, None) + return action() + self._cv.wait() + + class PyPyChannel(Channel): """Helper class to create the set of send and recv port and encapsulate them inside a common structure. We call this a PyPyChannel""" diff --git a/src/lava/magma/core/model/py/model.py b/src/lava/magma/core/model/py/model.py index d6d40a7d9..fab00722c 100644 --- a/src/lava/magma/core/model/py/model.py +++ b/src/lava/magma/core/model/py/model.py @@ -6,7 +6,8 @@ import numpy as np -from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort +from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort,\ + CspSelector from lava.magma.core.model.model import AbstractProcessModel from lava.magma.core.model.py.ports import AbstractPyPort, PyVarPort from lava.magma.runtime.mgmt_token_enums import ( @@ -114,9 +115,10 @@ def run(self): (service_to_process_cmd). After calling the method of a phase of all ProcessModels the runtime service is informed about completion. The loop ends when the STOP command is received.""" + selector = CspSelector() + action = 'cmd' while True: - # Probe if there is a new command from the runtime service - if self.service_to_process_cmd.probe(): + if action == 'cmd': phase = self.service_to_process_cmd.recv() if enum_equal(phase, MGMT_COMMAND.STOP): self.process_to_service_ack.send(MGMT_RESPONSE.TERMINATED) @@ -133,9 +135,6 @@ def run(self): if self.pre_guard(): self.run_pre_mgmt() self.process_to_service_ack.send(MGMT_RESPONSE.DONE) - # Handle VarPort requests from RefPorts - if len(self.var_ports) > 0: - self._handle_var_ports() # Learning phase elif enum_equal(phase, PyLoihiProcessModel.Phase.LRN): # Enable via guard method @@ -148,15 +147,29 @@ def run(self): if self.post_guard(): self.run_post_mgmt() self.process_to_service_ack.send(MGMT_RESPONSE.DONE) - # Handle VarPort requests from RefPorts - if len(self.var_ports) > 0: - self._handle_var_ports() # Host phase - called at the last time step before STOP elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST): - # Handle get/set Var requests from runtime service - self._handle_get_set_var() + pass else: raise ValueError(f"Wrong Phase Info Received : {phase}") + elif action == 'req': + # Handle get/set Var requests from runtime service + self._handle_get_set_var() + else: + # Handle VarPort requests from RefPorts + self._handle_var_port(action) + + channel_actions = [(self.service_to_process_cmd, lambda: 'cmd')] + if enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT) or \ + enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT): + for var_port in self.var_ports: + for csp_port in var_port.csp_ports: + if isinstance(csp_port, CspRecvPort): + channel_actions.append((csp_port, lambda: var_port)) + elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST): + channel_actions.append((self.service_to_process_req, + lambda: 'req')) + action = selector.select(*channel_actions) # FIXME: (PP) might not be able to perform get/set during pause def _handle_get_set_var(self): @@ -164,21 +177,14 @@ def _handle_get_set_var(self): the corresponding handling methods. The loop ends upon a new command from runtime service after all get/set Var requests have been handled.""" - while True: - # Probe if there is a get/set Var request from runtime service - if self.service_to_process_req.probe(): - # Get the type of the request - request = self.service_to_process_req.recv() - if enum_equal(request, REQ_TYPE.GET): - self._handle_get_var() - elif enum_equal(request, REQ_TYPE.SET): - self._handle_set_var() - else: - raise RuntimeError(f"Unknown request type {request}") - - # End if another command from runtime service arrives - if self.service_to_process_cmd.probe(): - return + # Get the type of the request + request = self.service_to_process_req.recv() + if enum_equal(request, REQ_TYPE.GET): + self._handle_get_var() + elif enum_equal(request, REQ_TYPE.SET): + self._handle_set_var() + else: + raise RuntimeError(f"Unknown request type {request}") def _handle_get_var(self): """Handles the get Var command from runtime service.""" @@ -233,16 +239,6 @@ def _handle_set_var(self): else: raise RuntimeError("Unsupported type") - # TODO: (PP) use select(..) to service VarPorts instead of a loop - def _handle_var_ports(self): - """Handles read/write requests on any VarPorts. The loop ends upon a - new command from runtime service after all VarPort service requests have - been handled.""" - while True: - # Loop through read/write requests of each VarPort - for vp in self.var_ports: - vp.service() - - # End if another command from runtime service arrives - if self.service_to_process_cmd.probe(): - return + def _handle_var_port(self, var_port): + """Handles read/write requests on the given VarPort.""" + var_port.service() diff --git a/src/lava/magma/runtime/runtime_service.py b/src/lava/magma/runtime/runtime_service.py index e2e3e1065..dd7d2f924 100644 --- a/src/lava/magma/runtime/runtime_service.py +++ b/src/lava/magma/runtime/runtime_service.py @@ -6,7 +6,8 @@ import numpy as np -from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort +from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort,\ + CspSelector from lava.magma.core.sync.protocol import AbstractSyncProtocol from lava.magma.runtime.mgmt_token_enums import ( enum_to_np, @@ -132,9 +133,8 @@ def _get_pm_resp(self) -> ty.Iterable[MGMT_RESPONSE]: counter = 0 while counter < num_responses_expected: ptos_recv_port = self.process_to_service_ack[counter] - if ptos_recv_port.probe(): - rcv_msgs.append(ptos_recv_port.recv()) - counter += 1 + rcv_msgs.append(ptos_recv_port.recv()) + counter += 1 return rcv_msgs def _relay_to_runtime_data_given_model_id(self, model_id: int): @@ -180,10 +180,13 @@ def run(self): In this case iterate through the phases of the Loihi protocol until the last time step is reached. The runtime is informed after the last time step. The loop ends when receiving the STOP command from the runtime.""" + selector = CspSelector() phase = LoihiPyRuntimeService.Phase.HOST while True: # Probe if there is a new command from the runtime - if self.runtime_to_service_cmd.probe(): + cmd = selector.select((self.runtime_to_service_cmd, lambda: True), + (self.runtime_to_service_req, lambda: False)) + if cmd: command = self.runtime_to_service_cmd.recv() if enum_equal(command, MGMT_COMMAND.STOP): # Inform all ProcessModels about the STOP command @@ -238,49 +241,44 @@ def run(self): # Inform the runtime that last time step was reached self.service_to_runtime_ack.send(MGMT_RESPONSE.DONE) - - # Handle get/set Var - self._handle_get_set(phase) + else: + # Handle get/set Var + self._handle_get_set(phase) def _handle_get_set(self, phase): if enum_equal(phase, LoihiPyRuntimeService.Phase.HOST): - while True: - if self.runtime_to_service_req.probe(): - request = self.runtime_to_service_req.recv() - if enum_equal(request, REQ_TYPE.GET): - requests: ty.List[np.ndarray] = [request] - # recv model_id - model_id: int = \ - self.runtime_to_service_req.recv()[ - 0].item() - # recv var_id - requests.append( - self.runtime_to_service_req.recv()) - self._send_pm_req_given_model_id(model_id, - *requests) - - self._relay_to_runtime_data_given_model_id( - model_id) - elif enum_equal(request, REQ_TYPE.SET): - requests: ty.List[np.ndarray] = [request] - # recv model_id - model_id: int = \ - self.runtime_to_service_req.recv()[ - 0].item() - # recv var_id - requests.append( - self.runtime_to_service_req.recv()) - self._send_pm_req_given_model_id(model_id, - *requests) - - self._relay_to_pm_data_given_model_id( - model_id) - else: - raise RuntimeError( - f"Unknown request {request}") - - if self.runtime_to_service_cmd.probe(): - return + request = self.runtime_to_service_req.recv() + if enum_equal(request, REQ_TYPE.GET): + requests: ty.List[np.ndarray] = [request] + # recv model_id + model_id: int = \ + self.runtime_to_service_req.recv()[ + 0].item() + # recv var_id + requests.append( + self.runtime_to_service_req.recv()) + self._send_pm_req_given_model_id(model_id, + *requests) + + self._relay_to_runtime_data_given_model_id( + model_id) + elif enum_equal(request, REQ_TYPE.SET): + requests: ty.List[np.ndarray] = [request] + # recv model_id + model_id: int = \ + self.runtime_to_service_req.recv()[ + 0].item() + # recv var_id + requests.append( + self.runtime_to_service_req.recv()) + self._send_pm_req_given_model_id(model_id, + *requests) + + self._relay_to_pm_data_given_model_id( + model_id) + else: + raise RuntimeError( + f"Unknown request {request}") class LoihiCRuntimeService(AbstractRuntimeService): From ec1c463b9fad212e1db7d59f3b2430e85dd2ab31 Mon Sep 17 00:00:00 2001 From: "Liu, Ruokun" Date: Sun, 21 Nov 2021 12:57:19 -0800 Subject: [PATCH 3/3] Use multiprocessing semaphore instead of pipe. --- .../magma/compiler/channels/pypychannel.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/lava/magma/compiler/channels/pypychannel.py b/src/lava/magma/compiler/channels/pypychannel.py index ca768b628..0450efd5a 100644 --- a/src/lava/magma/compiler/channels/pypychannel.py +++ b/src/lava/magma/compiler/channels/pypychannel.py @@ -8,7 +8,7 @@ from dataclasses import dataclass import numpy as np -from multiprocessing import Pipe +from multiprocessing import Semaphore from lava.magma.compiler.channels.interfaces import ( Channel, @@ -42,8 +42,8 @@ def __init__(self, name, shm, proto, size, req, ack): shm : SharedMemory proto : Proto size : int - req : Pipe - ack : Pipe + req : Semaphore + ack : Semaphore """ self._name = name self._shm = shm @@ -97,7 +97,7 @@ def start(self): def _ack_callback(self): try: while not self._done: - self._ack.recv_bytes(0) + self._ack.acquire() not_full = self.probe() self._semaphore.release() if self.observer and not not_full: @@ -124,7 +124,7 @@ def send(self, data): self._semaphore.acquire() self._array[self._idx][:] = data[:] self._idx = (self._idx + 1) % self._size - self._req.send_bytes(bytes(0)) + self._req.release() def join(self): self._done = True @@ -179,8 +179,8 @@ def __init__(self, name, shm, proto, size, req, ack): shm : SharedMemory proto : Proto size : int - req : Pipe - ack : Pipe + req : Semaphore + ack : Semaphore """ self._name = name self._shm = shm @@ -234,7 +234,7 @@ def start(self): def _req_callback(self): try: while not self._done: - self._req.recv_bytes(0) + self._req.acquire() not_empty = self.probe() self._queue.put_nowait(0) if self.observer and not not_empty: @@ -265,7 +265,7 @@ def recv(self): self._queue.get() result = self._array[self._idx].copy() self._idx = (self._idx + 1) % self._size - self._ack.send_bytes(bytes(0)) + self._ack.release() return result @@ -335,11 +335,11 @@ def __init__(self, nbytes = np.prod(shape) * np.dtype(dtype).itemsize smm = message_infrastructure.smm shm = smm.SharedMemory(int(nbytes * size)) - req = Pipe(duplex=False) - ack = Pipe(duplex=False) + req = Semaphore(0) + ack = Semaphore(0) proto = Proto(shape=shape, dtype=dtype, nbytes=nbytes) - self._src_port = CspSendPort(src_name, shm, proto, size, req[1], ack[0]) - self._dst_port = CspRecvPort(dst_name, shm, proto, size, req[0], ack[1]) + self._src_port = CspSendPort(src_name, shm, proto, size, req, ack) + self._dst_port = CspRecvPort(dst_name, shm, proto, size, req, ack) @property def src_port(self) -> AbstractCspSendPort: