Skip to content

Commit

Permalink
Enable exception proc_map, working with dataclasses, etc. (#372)
Browse files Browse the repository at this point in the history
* Enable exception proc_map, working with dataclasses, etc.

Signed-off-by: GaboFGuerra <[email protected]>

* Fix lint issue and check if run_cfg has exception model map

Signed-off-by: GaboFGuerra <[email protected]>
Co-authored-by: Marcus G K Williams <[email protected]>
Co-authored-by: Joyesh Mishra <[email protected]>
  • Loading branch information
3 people authored and Marcus G K Williams committed Sep 28, 2022
1 parent 65188e1 commit c294760
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
15 changes: 10 additions & 5 deletions src/lava/magma/compiler/compiler_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def _find_proc_models(proc: AbstractProcess) \
if not proc_module.__name__ == "__main__":
# Get the parent module.
module_spec = importlib.util.find_spec(proc_module.__name__)
if module_spec.parent:
if module_spec.parent != '':
parent_module = importlib.import_module(module_spec.parent)

# Get all the modules inside the parent (namespace) module.
Expand Down Expand Up @@ -992,16 +992,21 @@ def _map_proc_to_model(procs: ty.List[AbstractProcess],
proc_map = OrderedDict()
for proc in procs:
# Select a specific ProcessModel
models_cls = ProcGroupDiGraphs._find_proc_models(proc=proc)
model_cls = \
ProcGroupDiGraphs._select_proc_models(
proc, models_cls, run_cfg)
if hasattr(run_cfg, "exception_proc_model_map") and \
proc in run_cfg.exception_proc_model_map:
model_cls = run_cfg.exception_proc_model_map[proc]
else:
models_cls = ProcGroupDiGraphs._find_proc_models(proc=proc)
model_cls = ProcGroupDiGraphs._select_proc_models(proc,
models_cls,
run_cfg)
if issubclass(model_cls, AbstractSubProcessModel):
# Recursively substitute SubProcModel by sub processes
sub_map = ProcGroupDiGraphs._expand_sub_proc_model(model_cls,
proc,
run_cfg)
proc_map.update(sub_map)
proc._model_class = model_cls
else:
# Just map current Process to selected ProcessModel
proc_map[proc] = model_cls
Expand Down
9 changes: 8 additions & 1 deletion src/lava/magma/core/model/sub/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from lava.magma.core.model.model import AbstractProcessModel
from lava.magma.core.process.process import AbstractProcess
from dataclasses import is_dataclass, fields


class AbstractSubProcessModel(AbstractProcessModel):
Expand Down Expand Up @@ -58,6 +59,12 @@ def find_sub_procs(self) -> ty.Dict[str, AbstractProcess]:
procs = OrderedDict()
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, AbstractProcess):
if isinstance(attr, AbstractProcess) and \
attr is not self.implements_process:
procs[attr_name] = attr
if is_dataclass(attr):
for data in fields(attr):
sub_attr = getattr(attr, data.name)
if isinstance(sub_attr, AbstractProcess):
procs[type(sub_attr).__name__] = sub_attr
return procs

0 comments on commit c294760

Please sign in to comment.