diff --git a/src/lava/magma/compiler/compiler.py b/src/lava/magma/compiler/compiler.py index 1a3beec26..9937fe801 100644 --- a/src/lava/magma/compiler/compiler.py +++ b/src/lava/magma/compiler/compiler.py @@ -128,6 +128,8 @@ def compile( # ProcGroups. proc_group_digraph = ProcGroupDiGraphs(process, run_cfg) proc_groups: ty.List[ProcGroup] = proc_group_digraph.get_proc_groups() + # Get a flattened list of all AbstractProcesses + process_list = list(itertools.chain.from_iterable(proc_groups)) channel_map = ChannelMap.from_proc_groups(proc_groups) proc_builders, channel_map = self._compile_proc_groups( proc_groups, channel_map @@ -161,6 +163,7 @@ def compile( # Package all Builders and NodeConfigs into an Executable. executable = Executable( + process_list, proc_builders, channel_builders, node_configs, diff --git a/src/lava/magma/compiler/executable.py b/src/lava/magma/compiler/executable.py index e1d325615..42d290357 100644 --- a/src/lava/magma/compiler/executable.py +++ b/src/lava/magma/compiler/executable.py @@ -33,6 +33,7 @@ class Executable: # py_builders: ty.Dict[AbstractProcess, NcProcessBuilder] # c_builders: ty.Dict[AbstractProcess, CProcessBuilder] # nc_builders: ty.Dict[AbstractProcess, PyProcessBuilder] + process_list: ty.List[AbstractProcess] # All leaf processes, flat list. proc_builders: ty.Dict[AbstractProcess, 'AbstractProcessBuilder'] channel_builders: ty.List[ChannelBuilderMp] node_configs: ty.List[NodeConfig] @@ -43,5 +44,5 @@ class Executable: ty.Iterable[AbstractChannelBuilder]] = None def assign_runtime_to_all_processes(self, runtime): - for p in self.proc_builders.keys(): + for p in self.process_list: p.runtime = runtime diff --git a/tests/lava/magma/runtime/test_runtime.py b/tests/lava/magma/runtime/test_runtime.py index a70563638..6be322f45 100644 --- a/tests/lava/magma/runtime/test_runtime.py +++ b/tests/lava/magma/runtime/test_runtime.py @@ -27,7 +27,8 @@ def test_runtime_creation(self): def test_executable_node_config_assertion(self): """Tests runtime constructions with expected constraints""" - exe: Executable = Executable(proc_builders={}, + exe: Executable = Executable(process_list=[], + proc_builders={}, channel_builders=[], node_configs=[], sync_domains=[]) @@ -45,7 +46,8 @@ def test_executable_node_config_assertion(self): f"Expected type {expected_type} doesn't match {(type(runtime2))}") runtime2.stop() - exe1: Executable = Executable(proc_builders={}, + exe1: Executable = Executable(process_list=[], + proc_builders={}, channel_builders=[], node_configs=[], sync_domains=[])