Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Requesting Pause from Host. #373

Merged
merged 7 commits into from
Sep 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions src/lava/magma/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

if ty.TYPE_CHECKING:
from lava.magma.core.process.process import AbstractProcess

from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort, \
CspSelector
from lava.magma.compiler.builders.channel_builder import (
ChannelBuilderMp, RuntimeChannelBuilderMp, ServiceChannelBuilderMp)
from lava.magma.compiler.builders.interfaces import AbstractProcessBuilder
Expand Down Expand Up @@ -260,12 +261,23 @@ def _get_resp_for_run(self):
Gets response from RuntimeServices
"""
if self._is_running:
for recv_port in self.service_to_runtime:
selector = CspSelector()
# Poll on all responses
channel_actions = [(recv_port, (lambda y: (lambda: y))(
recv_port)) for
recv_port in
self.service_to_runtime]
rsps = []
while True:
recv_port = selector.select(*channel_actions)
data = recv_port.recv()
rsps.append(data)
if enum_equal(data, MGMT_RESPONSE.REQ_PAUSE):
self._req_paused = True
self.pause()
return
elif enum_equal(data, MGMT_RESPONSE.REQ_STOP):
self._req_stop = True
self.stop()
return
elif not enum_equal(data, MGMT_RESPONSE.DONE):
if enum_equal(data, MGMT_RESPONSE.ERROR):
# Receive all errors from the ProcessModels
Expand All @@ -282,13 +294,9 @@ def _get_resp_for_run(self):
f"output above for details.")
else:
raise RuntimeError(f"Runtime Received {data}")
if self._req_paused:
self._req_paused = False
self.pause()
if self._req_stop:
self._req_stop = False
self.stop()
self._is_running = False
if len(rsps) == len(self.service_to_runtime):
self._is_running = False
return

def start(self, run_condition: AbstractRunCondition):
"""
Expand Down Expand Up @@ -356,6 +364,13 @@ def pause(self):
raise RuntimeError(
f"{error_cnt} Exception(s) occurred. See "
f"output above for details.")
else:
if recv_port.probe():
data = recv_port.recv()
if not enum_equal(data, MGMT_RESPONSE.PAUSED):
raise RuntimeError(
f"{data} Got Wrong Response for Pause.")

self._is_running = False

def stop(self):
Expand All @@ -367,7 +382,10 @@ def stop(self):
for recv_port in self.service_to_runtime:
data = recv_port.recv()
if not enum_equal(data, MGMT_RESPONSE.TERMINATED):
raise RuntimeError(f"Runtime Received {data}")
if recv_port.probe():
data = recv_port.recv()
if not enum_equal(data, MGMT_RESPONSE.TERMINATED):
raise RuntimeError(f"Runtime Received {data}")
self.join()
self._is_running = False
self._is_started = False
Expand Down