From 60a2d94268ecd598f329a2b59e66cb68e7614186 Mon Sep 17 00:00:00 2001 From: GaboFGuerra Date: Mon, 26 Sep 2022 18:02:21 +0200 Subject: [PATCH 1/2] Enable exception proc_map, working with dataclasses, etc. Signed-off-by: GaboFGuerra --- src/lava/magma/compiler/compiler_graphs.py | 14 +++++++++----- src/lava/magma/core/model/sub/model.py | 9 ++++++++- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/lava/magma/compiler/compiler_graphs.py b/src/lava/magma/compiler/compiler_graphs.py index 4b9b03b1f..45b44996f 100644 --- a/src/lava/magma/compiler/compiler_graphs.py +++ b/src/lava/magma/compiler/compiler_graphs.py @@ -815,7 +815,7 @@ def _find_proc_models(proc: AbstractProcess) \ if not proc_module.__name__ == "__main__": # Get the parent module. module_spec = importlib.util.find_spec(proc_module.__name__) - if module_spec.parent: + if module_spec.parent != '': parent_module = importlib.import_module(module_spec.parent) # Get all the modules inside the parent (namespace) module. @@ -992,16 +992,20 @@ def _map_proc_to_model(procs: ty.List[AbstractProcess], proc_map = OrderedDict() for proc in procs: # Select a specific ProcessModel - models_cls = ProcGroupDiGraphs._find_proc_models(proc=proc) - model_cls = \ - ProcGroupDiGraphs._select_proc_models( - proc, models_cls, run_cfg) + if proc in run_cfg.exception_proc_model_map: + model_cls = run_cfg.exception_proc_model_map[proc] + else: + models_cls = ProcGroupDiGraphs._find_proc_models(proc=proc) + model_cls = ProcGroupDiGraphs._select_proc_models(proc, + models_cls, + run_cfg) if issubclass(model_cls, AbstractSubProcessModel): # Recursively substitute SubProcModel by sub processes sub_map = ProcGroupDiGraphs._expand_sub_proc_model(model_cls, proc, run_cfg) proc_map.update(sub_map) + proc._model_class = model_cls else: # Just map current Process to selected ProcessModel proc_map[proc] = model_cls diff --git a/src/lava/magma/core/model/sub/model.py b/src/lava/magma/core/model/sub/model.py index 110344609..1776984e5 100644 --- a/src/lava/magma/core/model/sub/model.py +++ b/src/lava/magma/core/model/sub/model.py @@ -8,6 +8,7 @@ from lava.magma.core.model.model import AbstractProcessModel from lava.magma.core.process.process import AbstractProcess +from dataclasses import is_dataclass, fields class AbstractSubProcessModel(AbstractProcessModel): @@ -58,6 +59,12 @@ def find_sub_procs(self) -> ty.Dict[str, AbstractProcess]: procs = OrderedDict() for attr_name in dir(self): attr = getattr(self, attr_name) - if isinstance(attr, AbstractProcess): + if isinstance(attr, AbstractProcess) and \ + not attr is self.implements_process: procs[attr_name] = attr + if is_dataclass(attr): + for data in fields(attr): + sub_attr = getattr(attr, data.name) + if isinstance(sub_attr, AbstractProcess): + procs[type(sub_attr).__name__] = sub_attr return procs From e5cc75fd5b9b0ad6200a068dfed30d9685803d7c Mon Sep 17 00:00:00 2001 From: Joyesh Mishra Date: Tue, 27 Sep 2022 12:45:03 -0700 Subject: [PATCH 2/2] Fix lint issue and check if run_cfg has exception model map --- src/lava/magma/compiler/compiler_graphs.py | 3 ++- src/lava/magma/core/model/sub/model.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lava/magma/compiler/compiler_graphs.py b/src/lava/magma/compiler/compiler_graphs.py index 45b44996f..c9148bf4d 100644 --- a/src/lava/magma/compiler/compiler_graphs.py +++ b/src/lava/magma/compiler/compiler_graphs.py @@ -992,7 +992,8 @@ def _map_proc_to_model(procs: ty.List[AbstractProcess], proc_map = OrderedDict() for proc in procs: # Select a specific ProcessModel - if proc in run_cfg.exception_proc_model_map: + if hasattr(run_cfg, "exception_proc_model_map") and \ + proc in run_cfg.exception_proc_model_map: model_cls = run_cfg.exception_proc_model_map[proc] else: models_cls = ProcGroupDiGraphs._find_proc_models(proc=proc) diff --git a/src/lava/magma/core/model/sub/model.py b/src/lava/magma/core/model/sub/model.py index 1776984e5..3edadd04b 100644 --- a/src/lava/magma/core/model/sub/model.py +++ b/src/lava/magma/core/model/sub/model.py @@ -60,7 +60,7 @@ def find_sub_procs(self) -> ty.Dict[str, AbstractProcess]: for attr_name in dir(self): attr = getattr(self, attr_name) if isinstance(attr, AbstractProcess) and \ - not attr is self.implements_process: + attr is not self.implements_process: procs[attr_name] = attr if is_dataclass(attr): for data in fields(attr):