Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
elvinhajizada committed Jul 4, 2023
2 parents 3228094 + 63e4834 commit 3e4a4fc
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def compile(
# ProcGroups.
proc_group_digraph = ProcGroupDiGraphs(process, run_cfg)
proc_groups: ty.List[ProcGroup] = proc_group_digraph.get_proc_groups()
# Get a flattened list of all AbstractProcesses
process_list = list(itertools.chain.from_iterable(proc_groups))
channel_map = ChannelMap.from_proc_groups(proc_groups)
proc_builders, channel_map = self._compile_proc_groups(
proc_groups, channel_map
Expand Down Expand Up @@ -161,6 +163,7 @@ def compile(

# Package all Builders and NodeConfigs into an Executable.
executable = Executable(
process_list,
proc_builders,
channel_builders,
node_configs,
Expand Down
3 changes: 2 additions & 1 deletion src/lava/magma/compiler/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Executable:
# py_builders: ty.Dict[AbstractProcess, NcProcessBuilder]
# c_builders: ty.Dict[AbstractProcess, CProcessBuilder]
# nc_builders: ty.Dict[AbstractProcess, PyProcessBuilder]
process_list: ty.List[AbstractProcess] # All leaf processes, flat list.
proc_builders: ty.Dict[AbstractProcess, 'AbstractProcessBuilder']
channel_builders: ty.List[ChannelBuilderMp]
node_configs: ty.List[NodeConfig]
Expand All @@ -43,5 +44,5 @@ class Executable:
ty.Iterable[AbstractChannelBuilder]] = None

def assign_runtime_to_all_processes(self, runtime):
for p in self.proc_builders.keys():
for p in self.process_list:
p.runtime = runtime
31 changes: 31 additions & 0 deletions src/lava/magma/core/callback_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from abc import ABC, abstractmethod
from typing import Iterable
try:
from nxcore.arch.base.nxboard import NxBoard
except ImportError:
Expand Down Expand Up @@ -53,3 +54,33 @@ def post_run_callback(self,
board: NxBoard = None,
var_id_to_var_model_map: dict = None):
pass


class IterableCallBack(NxSdkCallbackFx):
"""NxSDK callback function to execute iterable of function pointers
as pre and post run."""

def __init__(self,
pre_run_fxs: Iterable = None,
post_run_fxs: Iterable = None) -> None:
super().__init__()
if pre_run_fxs is None:
pre_run_fxs = []
if post_run_fxs is None:
post_run_fxs = []
self.pre_run_fxs = pre_run_fxs
self.post_run_fxs = post_run_fxs

def pre_run_callback(self,
board: NxBoard = None,
_var_id_to_var_model_map: dict = None
) -> None:
for fx in self.pre_run_fxs:
fx(board)

def post_run_callback(self,
board: NxBoard = None,
_var_id_to_var_model_map: dict = None
) -> None:
for fx in self.post_run_fxs:
fx(board)
6 changes: 4 additions & 2 deletions tests/lava/magma/runtime/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def test_runtime_creation(self):

def test_executable_node_config_assertion(self):
"""Tests runtime constructions with expected constraints"""
exe: Executable = Executable(proc_builders={},
exe: Executable = Executable(process_list=[],
proc_builders={},
channel_builders=[],
node_configs=[],
sync_domains=[])
Expand All @@ -45,7 +46,8 @@ def test_executable_node_config_assertion(self):
f"Expected type {expected_type} doesn't match {(type(runtime2))}")
runtime2.stop()

exe1: Executable = Executable(proc_builders={},
exe1: Executable = Executable(process_list=[],
proc_builders={},
channel_builders=[],
node_configs=[],
sync_domains=[])
Expand Down

0 comments on commit 3e4a4fc

Please sign in to comment.