Skip to content

Commit

Permalink
Serialization (#732)
Browse files Browse the repository at this point in the history
* serialization first try

* first try

* serialization implementation + unittests

* fix linting

* fix bandit

* fix unittest

* fix codacy

* added tutorial

* Update tutorial11_serialization.ipynb

* added notebook to unit tests

* Fixed broken link in tutorial.

---------

Co-authored-by: Mathis Richter <[email protected]>
  • Loading branch information
PhilippPlank and mathisrichter authored Jul 19, 2023
1 parent d01c9c3 commit 87d6a5a
Show file tree
Hide file tree
Showing 5 changed files with 612 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/lava/magma/core/process/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ def run(self,
self.create_runtime(run_cfg, compile_config)
self._runtime.start(condition)

def create_runtime(self, run_cfg: RunConfig,
def create_runtime(self, run_cfg: ty.Optional[RunConfig] = None,
executable: ty.Optional[Executable] = None,
compile_config:
ty.Optional[ty.Dict[str, ty.Any]] = None):
"""Creates a runtime for this process and all connected processes by
Expand All @@ -369,7 +370,8 @@ def create_runtime(self, run_cfg: RunConfig,
compile_config: Dict[str, Any], optional
Configuration options for the Compiler and SubCompilers.
"""
executable = self.compile(run_cfg, compile_config)
if executable is None:
executable = self.compile(run_cfg, compile_config)
self._runtime = Runtime(executable,
ActorType.MultiProcessing,
loglevel=self._log_config.level)
Expand Down Expand Up @@ -612,3 +614,9 @@ def __next__(self):
return getattr(self, self.member_names[self._iterator])
self._iterator = -1
raise StopIteration

def __getstate__(self):
return self.__dict__

def __setstate__(self, d):
self.__dict__ = d
128 changes: 128 additions & 0 deletions src/lava/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import pickle # noqa: S403 # nosec
import typing as ty
import os

from lava.magma.core.process.process import AbstractProcess
from lava.magma.compiler.executable import Executable


class SerializationObject:
"""This class is used to serialize a process or a list of processes
together with a corresponding executable.
Parameters
----------
processes: AbstractProcess, ty.List[AbstractProcess]
A process or a list of processes which should be stored in a file.
executable: Executable, optional
The corresponding executable of the compiled processes which should be
stored in a file.
"""
def __init__(self,
processes: ty.Union[AbstractProcess,
ty.List[AbstractProcess]],
executable: ty.Optional[Executable] = None) -> None:

self.processes = processes
self.executable = executable


def save(processes: ty.Union[AbstractProcess, ty.List[AbstractProcess]],
filename: str,
executable: ty.Optional[Executable] = None) -> None:
"""Saves a given process or list of processes with an (optional)
corresponding executable to file <filename>.
Parameters
----------
processes: AbstractProcess, ty.List[AbstractProcess]
A process or a list of processes which should be stored in a file.
filename: str
The path + name of the file. If no file extension is given,
'.pickle' will be added automatically.
executable: Executable, optional
The corresponding executable of the compiled processes which should be
stored in a file.
Raises
------
TypeError
If argument <process> is not AbstractProcess, argument <filename> is
not string or argument <executable> is not Executable.
"""
# Check parameter types
if not isinstance(processes, list) and not isinstance(processes,
AbstractProcess):
raise TypeError(f"Argument <processes> must be AbstractProcess"
f" or list of AbstractProcess, but got"
f" {processes}.")
if not isinstance(filename, str):
raise TypeError(f"Argument <filename> must be string"
f" but got {filename}.")
if executable is not None and not isinstance(executable, Executable):
raise TypeError(f"Argument <executable> must be Executable"
f" but got {executable}.")

# Create object which is stored
obj = SerializationObject(processes, executable)

# Add default file extension if no extension is present
if "." not in filename:
filename = filename + ".pickle"

# Store object at <filename>
with open(filename, 'wb') as f:
pickle.dump(obj, f)


def load(filename: str) -> ty.Tuple[ty.Union[AbstractProcess,
ty.List[AbstractProcess]],
ty.Union[None, Executable]]:
"""Loads a process or list of processes with an (optional)
corresponding executable from file <filename>.
Parameters
----------
filename: str
The path + name of the file. If no file extension is given,
'.pickle' will be added automatically.
Returns
-------
tuple
Returns a tuple of a process or list of processes and a executable or
None.
Raises
------
OSError
If the input file does not exist or cannot be read.
TypeError
If argument <filename> is not a string.
AssertionError
If provided file is not compatible/contains unexpected data.
"""

# Check parameter types
if not isinstance(filename, str):
raise TypeError(f"Argument <filename> must be string"
f" but got {filename}.")

# Check if filename exists
if not os.path.isfile(filename):
raise OSError(f"File {filename} could not be found.")

# Load serialized object from <filename>
with open(filename, 'rb') as f:
obj = pickle.load(f) # noqa: S301 # nosec

# Check loaded object
if not isinstance(obj, SerializationObject):
raise AssertionError(f"Incompatible file {filename} was provided.")

# Return processes and executable
return obj.processes, obj.executable
5 changes: 5 additions & 0 deletions tests/lava/tutorials/test_tutorials.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,11 @@ def test_in_depth_10_custom_learning_rules(self):
"""Test tutorial sigma_delta_neurons."""
self._run_notebook("tutorial10_sigma_delta_neurons.ipynb")

@unittest.skipIf(system_name != "linux", "Tests work on linux")
def test_in_depth_11_serialization(self):
"""Test tutorial serialization."""
self._run_notebook("tutorial11_serialization.ipynb")


if __name__ == "__main__":
support.run_unittest(TestTutorials)
177 changes: 177 additions & 0 deletions tests/lava/utils/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import unittest
import numpy as np
import tempfile
from lava.proc.lif.process import LIF
from lava.proc.dense.process import Dense
from lava.magma.core.run_configs import Loihi2SimCfg
from lava.magma.core.run_conditions import RunSteps
from lava.utils.serialization import save, load
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.model.sub.model import AbstractSubProcessModel
from lava.magma.core.decorator import implements
from lava.magma.core.process.variable import Var


# A minimal hierarchical process
class HP(AbstractProcess):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.lif_in_v = Var(shape=(2,))
self.lif_out_u = Var(shape=(3,))


# A minimal hierarchical PyProcModel implementing HP
@implements(proc=HP)
class PyProcModelHP(AbstractSubProcessModel):

def __init__(self, proc):
"""Builds sub Process structure of the Process."""

pre_size = 2
post_size = 3

weights = np.ones((post_size, pre_size))

self.lif_in = LIF(shape=(pre_size,), bias_mant=100, vth=120,
name="LIF_neuron input")
self.dense = Dense(weights=weights * 10, name="Dense")
self.lif_out = LIF(shape=(post_size,), bias_mant=0, vth=50000,
name="LIF_neuron output")

self.lif_in.s_out.connect(self.dense.s_in)
self.dense.a_out.connect(self.lif_out.a_in)

proc.vars.lif_in_v.alias(self.lif_in.v)
proc.vars.lif_out_u.alias(self.lif_out.u)


class TestSerialization(unittest.TestCase):
def test_save_input_validation(self):
"""Checks the input validation of save()."""

# Parameter processes needs to be AbstractProcess or list of
# AbstractProcess
with self.assertRaises(TypeError):
save(processes=None, filename="test")

# Parameter filename needs to be string
with self.assertRaises(TypeError):
save(processes=[], filename=1)

# Parameter executable needs to be Executable
with self.assertRaises(TypeError):
save(processes=[], filename="test", executable=1)

def test_load_input_validation(self):
"""Checks the input validation of load()."""

# Parameter filename needs to be string
with self.assertRaises(TypeError):
load(filename=1)

def test_save_load_processes(self):
"""Checks storing and loading processes."""

weights = np.ones((2, 3))

# Create some processes
dense = Dense(weights=weights, name="Dense")
lif_procs = []
for i in range(5):
lif_procs.append(LIF(shape=(1,), name="LIF" + str(i)))

# Store the processes in file test.pickle
with tempfile.TemporaryDirectory() as tmpdirname:
save(lif_procs + [dense], tmpdirname + "test")
dense = None

# Load the processes again from test.pickle
procs, _ = load(tmpdirname + "test.pickle")

dense = procs[-1]

# Check if the processes have the same parameters
self.assertTrue(np.all(dense.weights.get() == weights))
self.assertTrue(dense.name == "Dense")

for i in range(5):
self.assertTrue(isinstance(procs[i], LIF))
self.assertTrue(procs[i].name == "LIF" + str(i))

def test_save_load_executable(self):
"""Checks storing and loading of executable."""

# Create a process
lif = LIF(shape=(1,), name="ExecLIF")

# Create an executable
ex = lif.compile(run_cfg=Loihi2SimCfg())

# Store the executable in file test.pickle
with tempfile.TemporaryDirectory() as tmpdirname:
save([], tmpdirname + "test", executable=ex)

# Load the executable from test.pickle
p, executable = load(tmpdirname + "test.pickle")

# Check if the executable reflects the inital process
self.assertTrue(p == [])
loaded_lif = executable.process_list[0]
self.assertTrue(lif.name == loaded_lif.name)

def test_save_load_hierarchical_proc(self):
"""Checks saving, loading and execution of a workload using a
hierarchical process."""

num_steps = 5
output_lif_in_v = np.zeros(shape=(2, num_steps))
output_lif_out_u = np.zeros(shape=(3, num_steps))

# Create hierarchical process
proc = HP()

# Create executable
ex = proc.compile(run_cfg=Loihi2SimCfg())

# Store executable and run it
with tempfile.TemporaryDirectory() as tmpdirname:
save(proc, tmpdirname + "test", ex)

proc.create_runtime(executable=ex)
try:
for i in range(num_steps):
proc.run(condition=RunSteps(num_steps=1))

output_lif_in_v[:, i] = proc.lif_in_v.get()
output_lif_out_u[:, i] = proc.lif_out_u.get()
finally:
proc.stop()

# Load executable again
proc_loaded, ex_loaded = load(tmpdirname + "test.pickle")

output_lif_in_v_loaded = np.zeros(shape=(2, num_steps))
output_lif_out_u_loaded = np.zeros(shape=(3, num_steps))

# Run the loaded executable
proc_loaded.create_runtime(executable=ex_loaded)
try:
for i in range(num_steps):
proc_loaded.run(condition=RunSteps(num_steps=1))

output_lif_in_v_loaded[:, i] = proc_loaded.lif_in_v.get()
output_lif_out_u_loaded[:, i] = proc_loaded.lif_out_u.get()
finally:
proc_loaded.stop()

# Compare results from inital run and run of loaded executable
self.assertTrue(np.all(output_lif_in_v == output_lif_in_v_loaded))
self.assertTrue(np.all(output_lif_out_u == output_lif_out_u_loaded))


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 87d6a5a

Please sign in to comment.