diff --git a/src/lava/magma/runtime/runtime.py b/src/lava/magma/runtime/runtime.py index 56abaea20..32abfed43 100644 --- a/src/lava/magma/runtime/runtime.py +++ b/src/lava/magma/runtime/runtime.py @@ -25,7 +25,8 @@ if ty.TYPE_CHECKING: from lava.magma.core.process.process import AbstractProcess - +from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort, \ + CspSelector from lava.magma.compiler.builders.channel_builder import ( ChannelBuilderMp, RuntimeChannelBuilderMp, ServiceChannelBuilderMp) from lava.magma.compiler.builders.interfaces import AbstractProcessBuilder @@ -260,12 +261,23 @@ def _get_resp_for_run(self): Gets response from RuntimeServices """ if self._is_running: - for recv_port in self.service_to_runtime: + selector = CspSelector() + # Poll on all responses + channel_actions = [(recv_port, (lambda y: (lambda: y))( + recv_port)) for + recv_port in + self.service_to_runtime] + rsps = [] + while True: + recv_port = selector.select(*channel_actions) data = recv_port.recv() + rsps.append(data) if enum_equal(data, MGMT_RESPONSE.REQ_PAUSE): - self._req_paused = True + self.pause() + return elif enum_equal(data, MGMT_RESPONSE.REQ_STOP): - self._req_stop = True + self.stop() + return elif not enum_equal(data, MGMT_RESPONSE.DONE): if enum_equal(data, MGMT_RESPONSE.ERROR): # Receive all errors from the ProcessModels @@ -282,13 +294,9 @@ def _get_resp_for_run(self): f"output above for details.") else: raise RuntimeError(f"Runtime Received {data}") - if self._req_paused: - self._req_paused = False - self.pause() - if self._req_stop: - self._req_stop = False - self.stop() - self._is_running = False + if len(rsps) == len(self.service_to_runtime): + self._is_running = False + return def start(self, run_condition: AbstractRunCondition): """ @@ -356,6 +364,13 @@ def pause(self): raise RuntimeError( f"{error_cnt} Exception(s) occurred. See " f"output above for details.") + else: + if recv_port.probe(): + data = recv_port.recv() + if not enum_equal(data, MGMT_RESPONSE.PAUSED): + raise RuntimeError( + f"{data} Got Wrong Response for Pause.") + self._is_running = False def stop(self): @@ -367,7 +382,10 @@ def stop(self): for recv_port in self.service_to_runtime: data = recv_port.recv() if not enum_equal(data, MGMT_RESPONSE.TERMINATED): - raise RuntimeError(f"Runtime Received {data}") + if recv_port.probe(): + data = recv_port.recv() + if not enum_equal(data, MGMT_RESPONSE.TERMINATED): + raise RuntimeError(f"Runtime Received {data}") self.join() self._is_running = False self._is_started = False