Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update runtime and process to be "with" context managers. #605

Merged
merged 11 commits into from
Feb 15, 2023
10 changes: 10 additions & 0 deletions src/lava/magma/core/process/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ def __del__(self):
"""
self.stop()

def __enter__(self):
"""Required for "with" block."""
Gavinator98 marked this conversation as resolved.
Show resolved Hide resolved
pass

def __exit__(self, exc_type, exc_val, exc_tb):
"""
Stop the runtime when exiting "with" block.
"""
Gavinator98 marked this conversation as resolved.
Show resolved Hide resolved
self.stop()

def _post_init(self):
"""Called after __init__() method of any sub class via
ProcessMetaClass to finalize initialization leading to following
Expand Down
8 changes: 8 additions & 0 deletions src/lava/magma/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ def __del__(self):
if self._is_started:
self.stop()

def __enter__(self):
"""Initialize the runtime on entering a "with" block"""
Gavinator98 marked this conversation as resolved.
Show resolved Hide resolved
self.initialize()

def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop the runtime when exiting "with" block."""
self.stop()

def initialize(self, node_cfg_idx: int = 0):
"""Initializes the runtime"""
self._build_message_infrastructure()
Expand Down
114 changes: 114 additions & 0 deletions tests/lava/magma/runtime/test_context_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import unittest
Gavinator98 marked this conversation as resolved.
Show resolved Hide resolved
from time import sleep

from lava.magma.compiler.compiler import Compiler
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.process.message_interface_enum import ActorType
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.resources import CPU
from lava.magma.core.run_conditions import RunContinuous, RunSteps
from lava.magma.core.run_configs import RunConfig
from lava.magma.core.sync.domain import SyncDomain
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.runtime.runtime import Runtime


class SimpleProcess(AbstractProcess):
def __init__(self, **kwargs):
super().__init__()
shape = kwargs["shape"]
self.u = Var(shape=shape, init=0)
self.v = Var(shape=shape, init=0)


@implements(proc=SimpleProcess, protocol=LoihiProtocol)
@requires(CPU)
class SimpleProcessModel(PyLoihiProcessModel):
"""
Defines a SimpleProcessModel
"""
Gavinator98 marked this conversation as resolved.
Show resolved Hide resolved
u = LavaPyType(int, int)
v = LavaPyType(int, int)


class SimpleRunConfig(RunConfig):
"""
Defines a simple run config
"""
Gavinator98 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, **kwargs):
sync_domains = kwargs.pop("sync_domains")
super().__init__(custom_sync_domains=sync_domains)
self.model = None
if "model" in kwargs:
self.model = kwargs.pop("model")

def select(self, process, proc_models):
if self.model is not None:
if self.model == "sub" and isinstance(process, SimpleProcess):
return proc_models[1]

return proc_models[0]


class TestContextManager(unittest.TestCase):
def tearDown(self) -> None:
"""
Ensures process/runtime is stopped if context manager fails to
"""
self.stoppable.stop()

def test_context_manager_stops_process(self):
"""
Verifies context manager stops process when exiting "with" block
"""
process = SimpleProcess(shape=(2, 2))
self.stoppable = process
Gavinator98 marked this conversation as resolved.
Show resolved Hide resolved
simple_sync_domain = SyncDomain("simple", LoihiProtocol(), [process])
run_config = SimpleRunConfig(sync_domains=[simple_sync_domain])

with process:
process.run(condition=RunContinuous(), run_cfg=run_config)
self.assertTrue(process.runtime._is_running)
self.assertTrue(process.runtime._is_started)
sleep(2)

self.assertFalse(process.runtime._is_running)
self.assertFalse(process.runtime._is_started)

def test_context_manager_stops_runtime(self):
"""
Verifies context manager stops runtime when exiting "with" block
"""
self.process = SimpleProcess(shape=(2, 2))
Gavinator98 marked this conversation as resolved.
Show resolved Hide resolved
simple_sync_domain = SyncDomain("simple", LoihiProtocol(),
[self.process])
run_config = SimpleRunConfig(sync_domains=[simple_sync_domain])
compiler = Compiler()
executable = compiler.compile(self.process, run_config)
runtime = Runtime(executable,
ActorType.MultiProcessing)
executable.assign_runtime_to_all_processes(runtime)

self.stoppable = runtime

with runtime:
self.assertTrue(runtime._is_initialized)
self.assertFalse(runtime._is_running)
self.assertFalse(runtime._is_started)

runtime.start(run_condition=RunContinuous())

self.assertTrue(runtime._is_running)
self.assertTrue(runtime._is_started)
sleep(2)

self.assertFalse(runtime._is_running)
self.assertFalse(runtime._is_started)


if __name__ == '__main__':
unittest.main()