Performance improvements (lava-nc#87)
* Replace np.array_equal with a specialized comparison (enum_equal) for performance.

* Use select instead of probe with busy waiting.

* Use multiprocessing semaphore instead of pipe.
harryliu-intel authored Nov 22, 2021
1 parent b059bc2 commit b294baa
Showing 6 changed files with 180 additions and 128 deletions.
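The core of the change is in pypychannel.py: the Pipe-based req/ack signaling between CspSendPort and CspRecvPort becomes a pair of multiprocessing semaphores guarding a shared-memory ring buffer. The following is a minimal, self-contained sketch of that two-semaphore protocol, not code from this commit: the real ports additionally route the counters through per-port callback threads (_req_callback/_ack_callback), and every name, shape, and size below is illustrative.

import numpy as np
from multiprocessing import Process, Semaphore, shared_memory

SIZE, SHAPE, DTYPE = 4, (2,), np.int32      # ring depth, payload shape, payload dtype

def receiver(filled, free, shm_name, n_msgs):
    shm = shared_memory.SharedMemory(name=shm_name)
    ring = np.ndarray((SIZE,) + SHAPE, dtype=DTYPE, buffer=shm.buf)
    idx = 0
    for _ in range(n_msgs):
        filled.acquire()                    # block until the sender filled a slot ('req')
        print("received", ring[idx].copy())
        free.release()                      # hand the slot back to the sender ('ack')
        idx = (idx + 1) % SIZE
    shm.close()

if __name__ == "__main__":
    nbytes = SIZE * int(np.prod(SHAPE)) * np.dtype(DTYPE).itemsize
    shm = shared_memory.SharedMemory(create=True, size=nbytes)
    ring = np.ndarray((SIZE,) + SHAPE, dtype=DTYPE, buffer=shm.buf)
    filled, free = Semaphore(0), Semaphore(SIZE)   # 0 slots ready, SIZE slots empty
    proc = Process(target=receiver, args=(filled, free, shm.name, 6))
    proc.start()
    idx = 0
    for i in range(6):
        free.acquire()                      # block only when all SIZE slots are in flight
        ring[idx][:] = [i, i * i]           # write the payload into shared memory first
        filled.release()                    # then signal exactly one new message
        idx = (idx + 1) % SIZE
    proc.join()
    shm.close()
    shm.unlink()

The sender writes a payload only after free.acquire() confirms an empty slot, and the receiver reads only after filled.acquire() confirms a produced one, so neither side ever polls.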
73 changes: 59 additions & 14 deletions src/lava/magma/compiler/channels/pypychannel.py
@@ -3,12 +3,12 @@
# See: https://spdx.org/licenses/
import typing as ty
from queue import Queue, Empty
from threading import Thread
from threading import BoundedSemaphore, Condition, Thread
from time import time
from dataclasses import dataclass

import numpy as np
from multiprocessing import Pipe, BoundedSemaphore
from multiprocessing import Semaphore

from lava.magma.compiler.channels.interfaces import (
Channel,
@@ -42,8 +42,8 @@ def __init__(self, name, shm, proto, size, req, ack):
shm : SharedMemory
proto : Proto
size : int
req : Pipe
ack : Pipe
req : Semaphore
ack : Semaphore
"""
self._name = name
self._shm = shm
@@ -57,6 +57,7 @@ def __init__(self, name, shm, proto, size, req, ack):
self._done = False
self._array = []
self._semaphore = None
self.observer = None
self.thread = None

@property
Expand Down Expand Up @@ -96,8 +97,11 @@ def start(self):
def _ack_callback(self):
try:
while not self._done:
self._ack.recv_bytes(0)
self._ack.acquire()
not_full = self.probe()
self._semaphore.release()
if self.observer and not not_full:
self.observer()
except EOFError:
pass

@@ -120,7 +124,7 @@ def send(self, data):
self._semaphore.acquire()
self._array[self._idx][:] = data[:]
self._idx = (self._idx + 1) % self._size
self._req.send_bytes(bytes(0))
self._req.release()

def join(self):
self._done = True
@@ -175,8 +179,8 @@ def __init__(self, name, shm, proto, size, req, ack):
shm : SharedMemory
proto : Proto
size : int
req : Pipe
ack : Pipe
req : Semaphore
ack : Semaphore
"""
self._name = name
self._shm = shm
@@ -190,6 +194,7 @@ def __init__(self, name, shm, proto, size, req, ack):
self._done = False
self._array = []
self._queue = None
self.observer = None
self.thread = None

@property
Expand Down Expand Up @@ -229,8 +234,11 @@ def start(self):
def _req_callback(self):
try:
while not self._done:
self._req.recv_bytes(0)
self._req.acquire()
not_empty = self.probe()
self._queue.put_nowait(0)
if self.observer and not not_empty:
self.observer()
except EOFError:
pass

@@ -257,14 +265,51 @@ def recv(self):
self._queue.get()
result = self._array[self._idx].copy()
self._idx = (self._idx + 1) % self._size
self._ack.send_bytes(bytes(0))
self._ack.release()

return result

def join(self):
self._done = True


class CspSelector:
"""
Utility class to allow waiting for multiple channels to become ready
"""

def __init__(self):
"""Instantiates CspSelector object and class attributes"""
self._cv = Condition()

def _changed(self):
with self._cv:
self._cv.notify_all()

def _set_observer(self, channel_actions, observer):
for channel, _ in channel_actions:
channel.observer = observer

def select(
self,
*args: ty.Tuple[ty.Union[CspSendPort, CspRecvPort],
ty.Callable[[], ty.Any]
]
):
"""
Wait for any channel to become ready, then execute the corresponding
callable and return the result.
"""
with self._cv:
self._set_observer(args, self._changed)
while True:
for channel, action in args:
if channel.probe():
self._set_observer(args, None)
return action()
self._cv.wait()


class PyPyChannel(Channel):
"""Helper class to create the set of send and recv port and encapsulate
them inside a common structure. We call this a PyPyChannel"""
@@ -290,11 +335,11 @@ def __init__(self,
nbytes = np.prod(shape) * np.dtype(dtype).itemsize
smm = message_infrastructure.smm
shm = smm.SharedMemory(int(nbytes * size))
req = Pipe(duplex=False)
ack = Pipe(duplex=False)
req = Semaphore(0)
ack = Semaphore(0)
proto = Proto(shape=shape, dtype=dtype, nbytes=nbytes)
self._src_port = CspSendPort(src_name, shm, proto, size, req[1], ack[0])
self._dst_port = CspRecvPort(dst_name, shm, proto, size, req[0], ack[1])
self._src_port = CspSendPort(src_name, shm, proto, size, req, ack)
self._dst_port = CspRecvPort(dst_name, shm, proto, size, req, ack)

@property
def src_port(self) -> AbstractCspSendPort:
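To show how the new CspSelector is meant to be driven, here is a hypothetical usage sketch. It assumes a lava build containing this change is installed; FakePort is an illustrative stand-in exposing only the probe()/observer interface the selector relies on, not a real CspRecvPort.

import time
from queue import Queue
from threading import Thread

from lava.magma.compiler.channels.pypychannel import CspSelector

class FakePort:
    """Stand-in for a CspRecvPort: just probe() and an observer hook."""
    def __init__(self):
        self._queue = Queue()
        self.observer = None                # CspSelector installs its wake-up hook here

    def probe(self):
        return not self._queue.empty()

    def put(self, item):
        self._queue.put(item)
        if self.observer:
            self.observer()                 # notify a pending select()

    def recv(self):
        return self._queue.get()

if __name__ == "__main__":
    cmd_port, req_port = FakePort(), FakePort()
    # Deliver a "command" on cmd_port from another thread after 100 ms.
    Thread(target=lambda: (time.sleep(0.1), cmd_port.put("STOP"))).start()

    selector = CspSelector()
    # Sleep (no busy waiting) until either port has data, then run its callable.
    action = selector.select((cmd_port, lambda: "cmd"), (req_port, lambda: "req"))
    print(action, cmd_port.recv())          # -> cmd STOP

The observer hook plus the shared Condition is what removes the busy waiting: instead of spinning on probe(), the selecting thread sleeps in wait() until a port's callback signals that data has arrived.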
87 changes: 42 additions & 45 deletions src/lava/magma/core/model/py/model.py
@@ -6,11 +6,13 @@

import numpy as np

from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort
from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort,\
CspSelector
from lava.magma.core.model.model import AbstractProcessModel
from lava.magma.core.model.py.ports import AbstractPyPort, PyVarPort
from lava.magma.runtime.mgmt_token_enums import (
enum_to_np,
enum_equal,
MGMT_COMMAND,
MGMT_RESPONSE, REQ_TYPE,
)
@@ -113,71 +115,76 @@ def run(self):
(service_to_process_cmd). After calling the method of a phase of all
ProcessModels the runtime service is informed about completion. The
loop ends when the STOP command is received."""
selector = CspSelector()
action = 'cmd'
while True:
# Probe if there is a new command from the runtime service
if self.service_to_process_cmd.probe():
if action == 'cmd':
phase = self.service_to_process_cmd.recv()
if np.array_equal(phase, MGMT_COMMAND.STOP):
if enum_equal(phase, MGMT_COMMAND.STOP):
self.process_to_service_ack.send(MGMT_RESPONSE.TERMINATED)
self.join()
return
# Spiking phase - increase time step
if np.array_equal(phase, PyLoihiProcessModel.Phase.SPK):
if enum_equal(phase, PyLoihiProcessModel.Phase.SPK):
self.current_ts += 1
self.run_spk()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Pre-management phase
elif np.array_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT):
elif enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT):
# Enable via guard method
if self.pre_guard():
self.run_pre_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Handle VarPort requests from RefPorts
if len(self.var_ports) > 0:
self._handle_var_ports()
# Learning phase
elif np.array_equal(phase, PyLoihiProcessModel.Phase.LRN):
elif enum_equal(phase, PyLoihiProcessModel.Phase.LRN):
# Enable via guard method
if self.lrn_guard():
self.run_lrn()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Post-management phase
elif np.array_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
elif enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
# Enable via guard method
if self.post_guard():
self.run_post_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Handle VarPort requests from RefPorts
if len(self.var_ports) > 0:
self._handle_var_ports()
# Host phase - called at the last time step before STOP
elif np.array_equal(phase, PyLoihiProcessModel.Phase.HOST):
# Handle get/set Var requests from runtime service
self._handle_get_set_var()
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST):
pass
else:
raise ValueError(f"Wrong Phase Info Received : {phase}")
elif action == 'req':
# Handle get/set Var requests from runtime service
self._handle_get_set_var()
else:
# Handle VarPort requests from RefPorts
self._handle_var_port(action)

channel_actions = [(self.service_to_process_cmd, lambda: 'cmd')]
if enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT) or \
enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
for var_port in self.var_ports:
for csp_port in var_port.csp_ports:
if isinstance(csp_port, CspRecvPort):
channel_actions.append((csp_port, lambda: var_port))
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST):
channel_actions.append((self.service_to_process_req,
lambda: 'req'))
action = selector.select(*channel_actions)

# FIXME: (PP) might not be able to perform get/set during pause
def _handle_get_set_var(self):
"""Handles all get/set Var requests from the runtime service and calls
the corresponding handling methods. The loop ends upon a
new command from runtime service after all get/set Var requests have
been handled."""
while True:
# Probe if there is a get/set Var request from runtime service
if self.service_to_process_req.probe():
# Get the type of the request
request = self.service_to_process_req.recv()
if np.array_equal(request, REQ_TYPE.GET):
self._handle_get_var()
elif np.array_equal(request, REQ_TYPE.SET):
self._handle_set_var()
else:
raise RuntimeError(f"Unknown request type {request}")

# End if another command from runtime service arrives
if self.service_to_process_cmd.probe():
return
# Get the type of the request
request = self.service_to_process_req.recv()
if enum_equal(request, REQ_TYPE.GET):
self._handle_get_var()
elif enum_equal(request, REQ_TYPE.SET):
self._handle_set_var()
else:
raise RuntimeError(f"Unknown request type {request}")

def _handle_get_var(self):
"""Handles the get Var command from runtime service."""
@@ -232,16 +239,6 @@ def _handle_set_var(self):
else:
raise RuntimeError("Unsupported type")

# TODO: (PP) use select(..) to service VarPorts instead of a loop
def _handle_var_ports(self):
"""Handles read/write requests on any VarPorts. The loop ends upon a
new command from runtime service after all VarPort service requests have
been handled."""
while True:
# Loop through read/write requests of each VarPort
for vp in self.var_ports:
vp.service()

# End if another command from runtime service arrives
if self.service_to_process_cmd.probe():
return
def _handle_var_port(self, var_port):
"""Handles read/write requests on the given VarPort."""
var_port.service()
6 changes: 3 additions & 3 deletions src/lava/magma/core/model/py/ports.py
@@ -9,7 +9,7 @@
from lava.magma.compiler.channels.interfaces import AbstractCspPort
from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort
from lava.magma.core.model.interfaces import AbstractPortImplementation
from lava.magma.runtime.mgmt_token_enums import enum_to_np
from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal


class AbstractPyPort(AbstractPortImplementation):
@@ -301,10 +301,10 @@ def service(self):
cmd = enum_to_np(self._csp_recv_port.recv()[0])

# Set the value of the Var with the given data
if np.array_equal(cmd, VarPortCmd.SET):
if enum_equal(cmd, VarPortCmd.SET):
data = self._csp_recv_port.recv()
setattr(self._process_model, self.var_name, data)
elif np.array_equal(cmd, VarPortCmd.GET):
elif enum_equal(cmd, VarPortCmd.GET):
data = getattr(self._process_model, self.var_name)
self._csp_send_port.send(data)
else:
11 changes: 11 additions & 0 deletions src/lava/magma/runtime/mgmt_token_enums.py
@@ -20,6 +20,17 @@ def enum_to_np(value: ty.Union[int, float],
return np.array([value], dtype=d_type)


def enum_equal(a: np.array, b: np.array) -> bool:
"""
Helper function to compare two np arrays created by enum_to_np.
:param a: 1-D array created by enum_to_np
:param b: 1-D array created by enum_to_np
:return: True if the two arrays are equal
"""
return a[0] == b[0]


class MGMT_COMMAND:
"""
Signifies the Mgmt Command being sent between two actors. These may be
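A small illustration of the new helper (again assuming a lava build containing this commit; the values are arbitrary): enum_equal compares only the first element of the 1-element arrays produced by enum_to_np, sidestepping the general shape handling of np.array_equal.

import numpy as np
from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal

token = enum_to_np(3)                       # a 1-element array wrapping the value
assert enum_equal(token, enum_to_np(3))     # scalar compare: a[0] == b[0]
assert not enum_equal(token, enum_to_np(4))
# Agrees with the general-purpose comparison it replaces, minus its overhead:
assert enum_equal(token, enum_to_np(3)) == np.array_equal(token, enum_to_np(3))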
10 changes: 5 additions & 5 deletions src/lava/magma/runtime/runtime.py
@@ -14,8 +14,8 @@
import MessageInfrastructureInterface
from lava.magma.runtime.message_infrastructure.factory import \
MessageInfrastructureFactory
from lava.magma.runtime.mgmt_token_enums import MGMT_COMMAND, MGMT_RESPONSE, \
enum_to_np, REQ_TYPE
from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal, \
MGMT_COMMAND, MGMT_RESPONSE, REQ_TYPE
from lava.magma.runtime.runtime_service import AsyncPyRuntimeService

if ty.TYPE_CHECKING:
@@ -227,7 +227,7 @@ def _run(self, run_condition):
if run_condition.blocking:
for recv_port in self.service_to_runtime_ack:
data = recv_port.recv()
if not np.array_equal(data, MGMT_RESPONSE.DONE):
if not enum_equal(data, MGMT_RESPONSE.DONE):
raise RuntimeError(f"Runtime Received {data}")
if run_condition.blocking:
self.current_ts += self.num_steps
@@ -244,7 +244,7 @@ def wait(self):
if self._is_running:
for recv_port in self.service_to_runtime_ack:
data = recv_port.recv()
if not np.array_equal(data, MGMT_RESPONSE.DONE):
if not enum_equal(data, MGMT_RESPONSE.DONE):
raise RuntimeError(f"Runtime Received {data}")
self.current_ts += self.num_steps
self._is_running = False
@@ -260,7 +260,7 @@ def stop(self):
send_port.send(MGMT_COMMAND.STOP)
for recv_port in self.service_to_runtime_ack:
data = recv_port.recv()
if not np.array_equal(data, MGMT_RESPONSE.TERMINATED):
if not enum_equal(data, MGMT_RESPONSE.TERMINATED):
raise RuntimeError(f"Runtime Received {data}")
self.join()
self._is_running = False