Skip to content

Commit

Permalink
Fixed Pause for Async Process (lava-nc#646)
Browse files Browse the repository at this point in the history
* Fixed Pause for Async Process

* Update model.py

* Modified test to check for pause
  • Loading branch information
ysingh7 authored Mar 7, 2023
1 parent 93dd97f commit 5830a6e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 25 deletions.
49 changes: 30 additions & 19 deletions src/lava/magma/core/model/py/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class AbstractPyProcessModel(AbstractProcessModel, ABC):
"""

def __init__(
self,
proc_params: ty.Type["ProcessParameters"],
loglevel: ty.Optional[int] = logging.WARNING,
self,
proc_params: ty.Type["ProcessParameters"],
loglevel: ty.Optional[int] = logging.WARNING,
) -> None:
super().__init__(proc_params=proc_params, loglevel=loglevel)
self.model_id: ty.Optional[int] = None
Expand Down Expand Up @@ -479,14 +479,13 @@ def add_ports_for_polling(self):
Add various ports to poll for communication on ports
"""
if (
enum_equal(self.phase, PyLoihiProcessModel.Phase.PRE_MGMT)
or enum_equal(self.phase, PyLoihiProcessModel.Phase.POST_MGMT)
or enum_equal(self.phase, PyLoihiProcessModel.Phase.HOST)
enum_equal(self.phase, PyLoihiProcessModel.Phase.PRE_MGMT)
or enum_equal(self.phase, PyLoihiProcessModel.Phase.POST_MGMT)
or enum_equal(self.phase, PyLoihiProcessModel.Phase.HOST)
):
for var_port in self.var_ports:
for csp_port in var_port.csp_ports:
if isinstance(csp_port, CspRecvPort):

def func(fvar_port=var_port):
return lambda: fvar_port

Expand Down Expand Up @@ -526,6 +525,8 @@ class PyAsyncProcessModel(AbstractPyProcessModel):
def __init__(self, proc_params: ty.Optional["ProcessParameters"] = None):
super().__init__(proc_params=proc_params)
self.num_steps = 0
self._req_pause = False
self._req_stop = False
self._cmd_handlers.update({MGMT_COMMAND.RUN[0]: self._run_async})

class Response:
Expand All @@ -546,21 +547,23 @@ class Response:
REQ_STOP = enum_to_np(-5)
"""Signifies Request of STOP"""

def _pause(self):
"""
Command handler for Pause Command.
"""
pass

def check_for_stop_cmd(self) -> bool:
"""
Checks if the RS has sent a STOP command.
"""
if self.service_to_process.probe():
cmd = self.service_to_process.peek()
if enum_equal(cmd, MGMT_COMMAND.STOP):
self.service_to_process.recv()
self._stop()
return True
return False

def check_for_pause_cmd(self) -> bool:
"""
Checks if the RS has sent a PAUSE command.
"""
if self.service_to_process.probe():
cmd = self.service_to_process.peek()
if enum_equal(cmd, MGMT_COMMAND.PAUSE):
return True
return False

Expand All @@ -577,7 +580,15 @@ def _run_async(self):
"""
self.num_steps = int(self.service_to_process.recv()[0].item())
self.run_async()
self.process_to_service.send(PyAsyncProcessModel.Response.STATUS_DONE)
if self._req_stop:
self.process_to_service.send(PyAsyncProcessModel.Response.REQ_STOP)
self._req_stop = False
elif self._req_pause:
self.process_to_service.send(PyAsyncProcessModel.Response.REQ_PAUSE)
self._req_pause = False
else:
self.process_to_service.send(
PyAsyncProcessModel.Response.STATUS_DONE)

def add_ports_for_polling(self):
"""
Expand All @@ -587,7 +598,7 @@ def add_ports_for_polling(self):


def _get_attr_dict(
model_class: ty.Type[PyLoihiProcessModel],
model_class: ty.Type[PyLoihiProcessModel],
) -> ty.Dict[str, ty.Any]:
"""Get a dictionary of non-callable public attributes of a class.
Expand Down Expand Up @@ -617,7 +628,7 @@ def _get_attr_dict(


def _get_callable_dict(
model_class: ty.Type[PyLoihiProcessModel],
model_class: ty.Type[PyLoihiProcessModel],
) -> ty.Dict[str, ty.Callable]:
"""Get a dictionary of callable public members of a class.
Expand Down Expand Up @@ -648,7 +659,7 @@ def _get_callable_dict(


def PyLoihiModelToPyAsyncModel(
py_loihi_model: ty.Type[PyLoihiProcessModel],
py_loihi_model: ty.Type[PyLoihiProcessModel],
) -> ty.Type[PyAsyncProcessModel]:
"""Factory function that converts Py-Loihi process models
to equivalent Py-Async definition.
Expand Down
24 changes: 19 additions & 5 deletions src/lava/magma/runtime/runtime_services/runtime_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,19 +417,34 @@ def _send_pm_cmd(self, cmd: MGMT_COMMAND):
for stop_send_port in self.service_to_process:
stop_send_port.send(cmd)

def _get_pm_resp(self) -> ty.Iterable[MGMT_RESPONSE]:
def _get_pm_resp(self, stop=False, pause=False) -> ty.Iterable[
MGMT_RESPONSE]:
rcv_msgs = []
for ptos_recv_port in self.process_to_service:
rcv_msgs.append(ptos_recv_port.recv())
rcv_msg = ptos_recv_port.recv()
if stop or pause:
if enum_equal(
rcv_msg, LoihiPyRuntimeService.PMResponse.STATUS_DONE
):
rcv_msg = ptos_recv_port.recv()
rcv_msgs.append(rcv_msg)
return rcv_msgs

def _handle_pause(self):
# Inform the runtime about successful pausing
self._send_pm_cmd(MGMT_COMMAND.PAUSE)
rsps = self._get_pm_resp(pause=True)
for rsp in rsps:
if not enum_equal(
rsp, LoihiPyRuntimeService.PMResponse.STATUS_PAUSED
):
self.service_to_runtime.send(MGMT_RESPONSE.ERROR)
raise ValueError(f"Wrong Response Received : {rsp}")
self.service_to_runtime.send(MGMT_RESPONSE.PAUSED)

def _handle_stop(self):
self._send_pm_cmd(MGMT_COMMAND.STOP)
rsps = self._get_pm_resp()
rsps = self._get_pm_resp(stop=True)
for rsp in rsps:
if not enum_equal(
rsp, LoihiPyRuntimeService.PMResponse.STATUS_TERMINATED
Expand Down Expand Up @@ -493,8 +508,7 @@ def run(self):
):
self._error = True
if not enum_equal(resp,
AsyncPyRuntimeService.PMResponse.STATUS_DONE # noqa: E501
):
AsyncPyRuntimeService.PMResponse.STATUS_DONE): # noqa: E501
done = False
if done:
self.service_to_runtime.send(MGMT_RESPONSE.DONE)
Expand Down
4 changes: 3 additions & 1 deletion tests/lava/magma/runtime/test_async_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def run_async(self):
self.v = self.v + 1000
if self.check_for_stop_cmd():
return
if self.check_for_pause_cmd():
return


@implements(proc=AsyncProcess2, protocol=AsyncProtocol)
Expand Down Expand Up @@ -73,7 +75,7 @@ def test_async_process_model(self):

def test_async_process_model_pause(self):
"""
Verifies the working of Asynchronous Process, pause should have no
Verifies the working of Asynchronous Process, pause should have
effect
"""
process = AsyncProcess1(shape=(2, 2))
Expand Down

0 comments on commit 5830a6e

Please sign in to comment.