Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed Pause for Async Process #646

Merged
merged 4 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions src/lava/magma/core/model/py/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class AbstractPyProcessModel(AbstractProcessModel, ABC):
"""

def __init__(
self,
proc_params: ty.Type["ProcessParameters"],
loglevel: ty.Optional[int] = logging.WARNING,
self,
proc_params: ty.Type["ProcessParameters"],
loglevel: ty.Optional[int] = logging.WARNING,
) -> None:
super().__init__(proc_params=proc_params, loglevel=loglevel)
self.model_id: ty.Optional[int] = None
Expand Down Expand Up @@ -479,14 +479,13 @@ def add_ports_for_polling(self):
Add various ports to poll for communication on ports
"""
if (
enum_equal(self.phase, PyLoihiProcessModel.Phase.PRE_MGMT)
or enum_equal(self.phase, PyLoihiProcessModel.Phase.POST_MGMT)
or enum_equal(self.phase, PyLoihiProcessModel.Phase.HOST)
enum_equal(self.phase, PyLoihiProcessModel.Phase.PRE_MGMT)
or enum_equal(self.phase, PyLoihiProcessModel.Phase.POST_MGMT)
or enum_equal(self.phase, PyLoihiProcessModel.Phase.HOST)
):
for var_port in self.var_ports:
for csp_port in var_port.csp_ports:
if isinstance(csp_port, CspRecvPort):

def func(fvar_port=var_port):
return lambda: fvar_port

Expand Down Expand Up @@ -526,6 +525,8 @@ class PyAsyncProcessModel(AbstractPyProcessModel):
def __init__(self, proc_params: ty.Optional["ProcessParameters"] = None):
super().__init__(proc_params=proc_params)
self.num_steps = 0
self._req_pause = False
self._req_stop = False
self._cmd_handlers.update({MGMT_COMMAND.RUN[0]: self._run_async})

class Response:
Expand All @@ -546,21 +547,23 @@ class Response:
REQ_STOP = enum_to_np(-5)
"""Signifies Request of STOP"""

def _pause(self):
"""
Command handler for Pause Command.
"""
pass

def check_for_stop_cmd(self) -> bool:
"""
Checks if the RS has sent a STOP command.
"""
if self.service_to_process.probe():
cmd = self.service_to_process.peek()
if enum_equal(cmd, MGMT_COMMAND.STOP):
self.service_to_process.recv()
self._stop()
return True
return False

def check_for_pause_cmd(self) -> bool:
"""
Checks if the RS has sent a PAUSE command.
"""
if self.service_to_process.probe():
cmd = self.service_to_process.peek()
if enum_equal(cmd, MGMT_COMMAND.PAUSE):
return True
return False

Expand All @@ -577,7 +580,15 @@ def _run_async(self):
"""
self.num_steps = int(self.service_to_process.recv()[0].item())
self.run_async()
self.process_to_service.send(PyAsyncProcessModel.Response.STATUS_DONE)
if self._req_stop:
self.process_to_service.send(PyAsyncProcessModel.Response.REQ_STOP)
self._req_stop = False
elif self._req_pause:
self.process_to_service.send(PyAsyncProcessModel.Response.REQ_PAUSE)
self._req_pause = False
else:
self.process_to_service.send(
PyAsyncProcessModel.Response.STATUS_DONE)

def add_ports_for_polling(self):
"""
Expand All @@ -587,7 +598,7 @@ def add_ports_for_polling(self):


def _get_attr_dict(
model_class: ty.Type[PyLoihiProcessModel],
model_class: ty.Type[PyLoihiProcessModel],
) -> ty.Dict[str, ty.Any]:
"""Get a dictionary of non-callable public attributes of a class.

Expand Down Expand Up @@ -617,7 +628,7 @@ def _get_attr_dict(


def _get_callable_dict(
model_class: ty.Type[PyLoihiProcessModel],
model_class: ty.Type[PyLoihiProcessModel],
) -> ty.Dict[str, ty.Callable]:
"""Get a dictionary of callable public members of a class.

Expand Down Expand Up @@ -648,7 +659,7 @@ def _get_callable_dict(


def PyLoihiModelToPyAsyncModel(
py_loihi_model: ty.Type[PyLoihiProcessModel],
py_loihi_model: ty.Type[PyLoihiProcessModel],
) -> ty.Type[PyAsyncProcessModel]:
"""Factory function that converts Py-Loihi process models
to equivalent Py-Async definition.
Expand Down
24 changes: 19 additions & 5 deletions src/lava/magma/runtime/runtime_services/runtime_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,19 +417,34 @@ def _send_pm_cmd(self, cmd: MGMT_COMMAND):
for stop_send_port in self.service_to_process:
stop_send_port.send(cmd)

def _get_pm_resp(self) -> ty.Iterable[MGMT_RESPONSE]:
def _get_pm_resp(self, stop=False, pause=False) -> ty.Iterable[
MGMT_RESPONSE]:
rcv_msgs = []
for ptos_recv_port in self.process_to_service:
rcv_msgs.append(ptos_recv_port.recv())
rcv_msg = ptos_recv_port.recv()
if stop or pause:
if enum_equal(
rcv_msg, LoihiPyRuntimeService.PMResponse.STATUS_DONE
):
rcv_msg = ptos_recv_port.recv()
rcv_msgs.append(rcv_msg)
return rcv_msgs

def _handle_pause(self):
# Inform the runtime about successful pausing
self._send_pm_cmd(MGMT_COMMAND.PAUSE)
rsps = self._get_pm_resp(pause=True)
for rsp in rsps:
if not enum_equal(
rsp, LoihiPyRuntimeService.PMResponse.STATUS_PAUSED
):
self.service_to_runtime.send(MGMT_RESPONSE.ERROR)
raise ValueError(f"Wrong Response Received : {rsp}")
self.service_to_runtime.send(MGMT_RESPONSE.PAUSED)

def _handle_stop(self):
self._send_pm_cmd(MGMT_COMMAND.STOP)
rsps = self._get_pm_resp()
rsps = self._get_pm_resp(stop=True)
for rsp in rsps:
if not enum_equal(
rsp, LoihiPyRuntimeService.PMResponse.STATUS_TERMINATED
Expand Down Expand Up @@ -493,8 +508,7 @@ def run(self):
):
self._error = True
if not enum_equal(resp,
AsyncPyRuntimeService.PMResponse.STATUS_DONE # noqa: E501
):
AsyncPyRuntimeService.PMResponse.STATUS_DONE): # noqa: E501
done = False
if done:
self.service_to_runtime.send(MGMT_RESPONSE.DONE)
Expand Down
4 changes: 3 additions & 1 deletion tests/lava/magma/runtime/test_async_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def run_async(self):
self.v = self.v + 1000
if self.check_for_stop_cmd():
return
if self.check_for_pause_cmd():
return


@implements(proc=AsyncProcess2, protocol=AsyncProtocol)
Expand Down Expand Up @@ -73,7 +75,7 @@ def test_async_process_model(self):

def test_async_process_model_pause(self):
"""
Verifies the working of Asynchronous Process, pause should have no
Verifies the working of Asynchronous Process, pause should have
effect
"""
process = AsyncProcess1(shape=(2, 2))
Expand Down