Performance improvements #87

Merged 3 commits on Nov 22, 2021
Changes from all commits
73 changes: 59 additions & 14 deletions src/lava/magma/compiler/channels/pypychannel.py
@@ -3,12 +3,12 @@
# See: https://spdx.org/licenses/
import typing as ty
from queue import Queue, Empty
from threading import Thread
from threading import BoundedSemaphore, Condition, Thread
from time import time
from dataclasses import dataclass

import numpy as np
from multiprocessing import Pipe, BoundedSemaphore
from multiprocessing import Semaphore

from lava.magma.compiler.channels.interfaces import (
Channel,
@@ -42,8 +42,8 @@ def __init__(self, name, shm, proto, size, req, ack):
shm : SharedMemory
proto : Proto
size : int
req : Pipe
ack : Pipe
req : Semaphore
ack : Semaphore
"""
self._name = name
self._shm = shm
@@ -57,6 +57,7 @@ def __init__(self, name, shm, proto, size, req, ack):
self._done = False
self._array = []
self._semaphore = None
self.observer = None
self.thread = None

@property
Expand Down Expand Up @@ -96,8 +97,11 @@ def start(self):
def _ack_callback(self):
try:
while not self._done:
self._ack.recv_bytes(0)
self._ack.acquire()
not_full = self.probe()
self._semaphore.release()
if self.observer and not not_full:
self.observer()
except EOFError:
pass

@@ -120,7 +124,7 @@ def send(self, data):
self._semaphore.acquire()
self._array[self._idx][:] = data[:]
self._idx = (self._idx + 1) % self._size
self._req.send_bytes(bytes(0))
self._req.release()

def join(self):
self._done = True
@@ -175,8 +179,8 @@ def __init__(self, name, shm, proto, size, req, ack):
shm : SharedMemory
proto : Proto
size : int
req : Pipe
ack : Pipe
req : Semaphore
ack : Semaphore
"""
self._name = name
self._shm = shm
@@ -190,6 +194,7 @@
self._done = False
self._array = []
self._queue = None
self.observer = None
self.thread = None

@property
@@ -229,8 +234,11 @@ def start(self):
def _req_callback(self):
try:
while not self._done:
self._req.recv_bytes(0)
self._req.acquire()
not_empty = self.probe()
Contributor:
This loop is still busy waiting, isn't it? Is this a performance concern? Could this also be improved in performance by using a condition?

Contributor Author:
There is no busy waiting here, because self._req.recv_bytes() would have blocked.

self._queue.put_nowait(0)
if self.observer and not not_empty:
self.observer()
except EOFError:
pass

@@ -257,14 +265,51 @@ def recv(self):
self._queue.get()
result = self._array[self._idx].copy()
self._idx = (self._idx + 1) % self._size
self._ack.send_bytes(bytes(0))
self._ack.release()

return result

def join(self):
self._done = True


class CspSelector:
"""
Utility class to allow waiting for multiple channels to become ready
"""

def __init__(self):
"""Instantiates CspSelector object and class attributes"""
self._cv = Condition()

def _changed(self):
with self._cv:
self._cv.notify_all()

def _set_observer(self, channel_actions, observer):
for channel, _ in channel_actions:
channel.observer = observer

def select(
self,
*args: ty.Tuple[ty.Union[CspSendPort, CspRecvPort],
ty.Callable[[], ty.Any]
]
):
"""
Wait for any channel to become ready, then execute the corresponding
callable and return the result.
"""
with self._cv:
self._set_observer(args, self._changed)
Contributor:
Why do we have to set and unset this all the time?

Contributor Author:
In an earlier version, all channels notified, and _changed() determined if the source channel is part of the current set being selected on. (This is also the reason for the now unused channel argument, which I didn't bother to remove.) I thought doing it this way was cheaper and cleaner.

while True:
for channel, action in args:
if channel.probe():
self._set_observer(args, None)
return action()
self._cv.wait()
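
A rough usage sketch of the selector above, with a stand-in object in place of the real Csp ports (constructing those needs shared memory); FakePort is hypothetical and only mimics the probe()/observer interface that CspSelector relies on:

```python
import threading
import time

class FakePort:
    """Stand-in for a CspRecvPort: exposes probe() and an observer hook."""
    def __init__(self):
        self._ready = threading.Event()
        self.observer = None

    def probe(self):
        return self._ready.is_set()

    def make_ready(self):
        self._ready.set()
        if self.observer:
            self.observer()  # wake a selector blocked in select()

cmd_port, data_port = FakePort(), FakePort()
selector = CspSelector()

# Simulate data arriving on one channel a little later.
threading.Thread(target=lambda: (time.sleep(0.1), data_port.make_ready())).start()

action = selector.select(
    (cmd_port, lambda: 'cmd'),
    (data_port, lambda: 'data'),
)
print(action)  # -> 'data'
```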


class PyPyChannel(Channel):
"""Helper class to create the set of send and recv port and encapsulate
them inside a common structure. We call this a PyPyChannel"""
@@ -290,11 +335,11 @@ def __init__(self,
nbytes = np.prod(shape) * np.dtype(dtype).itemsize
smm = message_infrastructure.smm
shm = smm.SharedMemory(int(nbytes * size))
req = Pipe(duplex=False)
ack = Pipe(duplex=False)
req = Semaphore(0)
ack = Semaphore(0)
proto = Proto(shape=shape, dtype=dtype, nbytes=nbytes)
self._src_port = CspSendPort(src_name, shm, proto, size, req[1], ack[0])
self._dst_port = CspRecvPort(dst_name, shm, proto, size, req[0], ack[1])
self._src_port = CspSendPort(src_name, shm, proto, size, req, ack)
self._dst_port = CspRecvPort(dst_name, shm, proto, size, req, ack)

@property
def src_port(self) -> AbstractCspSendPort:
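
The core of the pypychannel change is that send/recv now block on semaphores instead of exchanging zero-byte pipe messages. A minimal, self-contained sketch of that handshake, using threading primitives and a plain Python list in place of the multiprocessing.Semaphore and shared memory used here:

```python
import threading

class TinyChannel:
    """Single-producer/single-consumer ring buffer with a semaphore handshake."""
    def __init__(self, size=2):
        self._size = size
        self._buf = [None] * size
        self._free = threading.BoundedSemaphore(size)  # free slots; sender blocks when full
        self._req = threading.Semaphore(0)              # filled slots; receiver blocks when empty
        self._send_idx = 0
        self._recv_idx = 0

    def send(self, item):
        self._free.acquire()       # wait for a free slot (no busy polling)
        self._buf[self._send_idx] = item
        self._send_idx = (self._send_idx + 1) % self._size
        self._req.release()        # signal the receiver that data is available

    def recv(self):
        self._req.acquire()        # block until the sender has filled a slot
        item = self._buf[self._recv_idx]
        self._recv_idx = (self._recv_idx + 1) % self._size
        self._free.release()       # return the slot to the sender
        return item

if __name__ == '__main__':
    ch = TinyChannel()
    producer = threading.Thread(target=lambda: [ch.send(i) for i in range(5)])
    producer.start()
    print([ch.recv() for _ in range(5)])  # [0, 1, 2, 3, 4]
    producer.join()
```
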
87 changes: 42 additions & 45 deletions src/lava/magma/core/model/py/model.py
@@ -6,11 +6,13 @@

import numpy as np

from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort
from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort,\
CspSelector
from lava.magma.core.model.model import AbstractProcessModel
from lava.magma.core.model.py.ports import AbstractPyPort, PyVarPort
from lava.magma.runtime.mgmt_token_enums import (
enum_to_np,
enum_equal,
MGMT_COMMAND,
MGMT_RESPONSE, REQ_TYPE,
)
@@ -113,71 +115,76 @@ def run(self):
(service_to_process_cmd). After calling the method of a phase of all
ProcessModels the runtime service is informed about completion. The
loop ends when the STOP command is received."""
selector = CspSelector()
action = 'cmd'
while True:
# Probe if there is a new command from the runtime service
if self.service_to_process_cmd.probe():
if action == 'cmd':
phase = self.service_to_process_cmd.recv()
Contributor:
Ok, so one key change is that instead of constantly calling probe(), we immediately block on the command channel until we receive anything.

if np.array_equal(phase, MGMT_COMMAND.STOP):
if enum_equal(phase, MGMT_COMMAND.STOP):
self.process_to_service_ack.send(MGMT_RESPONSE.TERMINATED)
self.join()
return
# Spiking phase - increase time step
if np.array_equal(phase, PyLoihiProcessModel.Phase.SPK):
if enum_equal(phase, PyLoihiProcessModel.Phase.SPK):
self.current_ts += 1
self.run_spk()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Pre-management phase
elif np.array_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT):
elif enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT):
# Enable via guard method
if self.pre_guard():
self.run_pre_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Handle VarPort requests from RefPorts
if len(self.var_ports) > 0:
self._handle_var_ports()
# Learning phase
elif np.array_equal(phase, PyLoihiProcessModel.Phase.LRN):
elif enum_equal(phase, PyLoihiProcessModel.Phase.LRN):
# Enable via guard method
if self.lrn_guard():
self.run_lrn()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Post-management phase
elif np.array_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
elif enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
# Enable via guard method
if self.post_guard():
self.run_post_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Handle VarPort requests from RefPorts
if len(self.var_ports) > 0:
self._handle_var_ports()
# Host phase - called at the last time step before STOP
elif np.array_equal(phase, PyLoihiProcessModel.Phase.HOST):
# Handle get/set Var requests from runtime service
self._handle_get_set_var()
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST):
pass
else:
raise ValueError(f"Wrong Phase Info Received : {phase}")
elif action == 'req':
Contributor:
My idea was to treat everything as a command, be it RUN, PAUSE, STOP, GET, SET, SPK, ..., and just handle them all consistently.

# Handle get/set Var requests from runtime service
self._handle_get_set_var()
else:
# Handle VarPort requests from RefPorts
self._handle_var_port(action)

channel_actions = [(self.service_to_process_cmd, lambda: 'cmd')]
if enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT) or \
enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
for var_port in self.var_ports:
for csp_port in var_port.csp_ports:
if isinstance(csp_port, CspRecvPort):
channel_actions.append((csp_port, lambda: var_port))
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST):
channel_actions.append((self.service_to_process_req,
Contributor:
From a type perspective it is a bit inconsistent to give either a string or a port back as an action.

lambda: 'req'))
action = selector.select(*channel_actions)

# FIXME: (PP) might not be able to perform get/set during pause
def _handle_get_set_var(self):
"""Handles all get/set Var requests from the runtime service and calls
the corresponding handling methods. The loop ends upon a
new command from runtime service after all get/set Var requests have
been handled."""
while True:
# Probe if there is a get/set Var request from runtime service
if self.service_to_process_req.probe():
# Get the type of the request
request = self.service_to_process_req.recv()
if np.array_equal(request, REQ_TYPE.GET):
self._handle_get_var()
elif np.array_equal(request, REQ_TYPE.SET):
self._handle_set_var()
else:
raise RuntimeError(f"Unknown request type {request}")

# End if another command from runtime service arrives
if self.service_to_process_cmd.probe():
return
# Get the type of the request
request = self.service_to_process_req.recv()
if enum_equal(request, REQ_TYPE.GET):
self._handle_get_var()
elif enum_equal(request, REQ_TYPE.SET):
self._handle_set_var()
else:
raise RuntimeError(f"Unknown request type {request}")

def _handle_get_var(self):
"""Handles the get Var command from runtime service."""
@@ -232,16 +239,6 @@ def _handle_set_var(self):
else:
raise RuntimeError("Unsupported type")

# TODO: (PP) use select(..) to service VarPorts instead of a loop
def _handle_var_ports(self):
"""Handles read/write requests on any VarPorts. The loop ends upon a
new command from runtime service after all VarPort service requests have
been handled."""
while True:
# Loop through read/write requests of each VarPort
for vp in self.var_ports:
vp.service()

# End if another command from runtime service arrives
if self.service_to_process_cmd.probe():
return
def _handle_var_port(self, var_port):
"""Handles read/write requests on the given VarPort."""
var_port.service()
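
Condensed, the reworked run() follows an event-driven dispatch pattern: handle whatever the last select() returned, then re-arm the selector with the channels that are valid for the current phase. A hypothetical sketch of that structure (handler and channel names are placeholders, and the per-phase gating of channels is omitted):

```python
def event_loop(selector, cmd_channel, req_channel, var_port_channels,
               handle_command, handle_request):
    action = 'cmd'
    while True:
        if action == 'cmd':
            if not handle_command(cmd_channel.recv()):
                return                      # e.g. STOP was received
        elif action == 'req':
            handle_request(req_channel.recv())
        else:
            action.service()                # action is the VarPort itself

        # Re-arm: always listen for commands; add further channels as needed.
        channel_actions = [(cmd_channel, lambda: 'cmd'),
                           (req_channel, lambda: 'req')]
        for var_port, csp_port in var_port_channels:
            # bind via a default argument so each lambda returns its own port
            channel_actions.append((csp_port, lambda vp=var_port: vp))
        action = selector.select(*channel_actions)
```
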
6 changes: 3 additions & 3 deletions src/lava/magma/core/model/py/ports.py
@@ -9,7 +9,7 @@
from lava.magma.compiler.channels.interfaces import AbstractCspPort
from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort
from lava.magma.core.model.interfaces import AbstractPortImplementation
from lava.magma.runtime.mgmt_token_enums import enum_to_np
from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal


class AbstractPyPort(AbstractPortImplementation):
@@ -301,10 +301,10 @@ def service(self):
cmd = enum_to_np(self._csp_recv_port.recv()[0])

# Set the value of the Var with the given data
if np.array_equal(cmd, VarPortCmd.SET):
if enum_equal(cmd, VarPortCmd.SET):
data = self._csp_recv_port.recv()
setattr(self._process_model, self.var_name, data)
elif np.array_equal(cmd, VarPortCmd.GET):
elif enum_equal(cmd, VarPortCmd.GET):
data = getattr(self._process_model, self.var_name)
self._csp_send_port.send(data)
else:
11 changes: 11 additions & 0 deletions src/lava/magma/runtime/mgmt_token_enums.py
@@ -20,6 +20,17 @@ def enum_to_np(value: ty.Union[int, float],
return np.array([value], dtype=d_type)


def enum_equal(a: np.array, b: np.array) -> bool:
"""
Helper function to compare two np arrays created by enum_to_np.

:param a: 1-D array created by enum_to_np
:param b: 1-D array created by enum_to_np
:return: True if the two arrays are equal
"""
return a[0] == b[0]


class MGMT_COMMAND:
"""
Signifies the Mgmt Command being sent between two actors. These may be
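
For reference, a small usage sketch of the two helpers above (re-declared here so the snippet runs on its own; the default dtype and the numeric value 4 are assumptions for illustration, not the real MGMT_COMMAND encoding):

```python
import numpy as np

def enum_to_np(value, d_type=np.int32):
    # mirrors the helper above: wrap a scalar in a 1-element array
    return np.array([value], dtype=d_type)

def enum_equal(a, b):
    # compare only the single elements, avoiding np.array_equal overhead
    return a[0] == b[0]

STOP = enum_to_np(4)                         # hypothetical command value
print(enum_equal(enum_to_np(4), STOP))       # True
print(enum_equal(enum_to_np(1), STOP))       # False
```
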
10 changes: 5 additions & 5 deletions src/lava/magma/runtime/runtime.py
@@ -14,8 +14,8 @@
import MessageInfrastructureInterface
from lava.magma.runtime.message_infrastructure.factory import \
MessageInfrastructureFactory
from lava.magma.runtime.mgmt_token_enums import MGMT_COMMAND, MGMT_RESPONSE, \
enum_to_np, REQ_TYPE
from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal, \
MGMT_COMMAND, MGMT_RESPONSE, REQ_TYPE
from lava.magma.runtime.runtime_service import AsyncPyRuntimeService

if ty.TYPE_CHECKING:
@@ -227,7 +227,7 @@ def _run(self, run_condition):
if run_condition.blocking:
for recv_port in self.service_to_runtime_ack:
data = recv_port.recv()
if not np.array_equal(data, MGMT_RESPONSE.DONE):
if not enum_equal(data, MGMT_RESPONSE.DONE):
raise RuntimeError(f"Runtime Received {data}")
if run_condition.blocking:
self.current_ts += self.num_steps
@@ -244,7 +244,7 @@ def wait(self):
if self._is_running:
for recv_port in self.service_to_runtime_ack:
data = recv_port.recv()
if not np.array_equal(data, MGMT_RESPONSE.DONE):
if not enum_equal(data, MGMT_RESPONSE.DONE):
raise RuntimeError(f"Runtime Received {data}")
self.current_ts += self.num_steps
self._is_running = False
@@ -260,7 +260,7 @@ def stop(self):
send_port.send(MGMT_COMMAND.STOP)
for recv_port in self.service_to_runtime_ack:
data = recv_port.recv()
if not np.array_equal(data, MGMT_RESPONSE.TERMINATED):
if not enum_equal(data, MGMT_RESPONSE.TERMINATED):
raise RuntimeError(f"Runtime Received {data}")
self.join()
self._is_running = False