Skip to content

Commit

Permalink
Iterator callback function (lava-nc#726)
Browse files Browse the repository at this point in the history
* update refport unittest to always wait when it writes to port for consistent behavior

Signed-off-by: bamsumit <[email protected]>

* Removed pyproject changes

Signed-off-by: bamsumit <[email protected]>

* Fix to convolution tests. Fixed imcompatible mnist_pretrained for old python versions.

Signed-off-by: bamsumit <[email protected]>

* Missing moudle parent fix

Signed-off-by: bamsumit <[email protected]>

* Added ConvVarModel

* Added iterable callback function

Signed-off-by: bamsumit <[email protected]>

* Fix codacy issues in callback_fx.py

* Fix linting in callback_fx.py

* Fix codacy sig issue in callback_fx.py

---------

Signed-off-by: bamsumit <[email protected]>
Co-authored-by: Joyesh Mishra <[email protected]>
Co-authored-by: Marcus G K Williams <[email protected]>
  • Loading branch information
3 people authored Jun 28, 2023
1 parent bd5ea94 commit 4d455b3
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/lava/magma/core/callback_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from abc import ABC, abstractmethod
from typing import Iterable
try:
from nxcore.arch.base.nxboard import NxBoard
except ImportError:
Expand Down Expand Up @@ -53,3 +54,33 @@ def post_run_callback(self,
board: NxBoard = None,
var_id_to_var_model_map: dict = None):
pass


class IterableCallBack(NxSdkCallbackFx):
"""NxSDK callback function to execute iterable of function pointers
as pre and post run."""

def __init__(self,
pre_run_fxs: Iterable = None,
post_run_fxs: Iterable = None) -> None:
super().__init__()
if pre_run_fxs is None:
pre_run_fxs = []
if post_run_fxs is None:
post_run_fxs = []
self.pre_run_fxs = pre_run_fxs
self.post_run_fxs = post_run_fxs

def pre_run_callback(self,
board: NxBoard = None,
_var_id_to_var_model_map: dict = None
) -> None:
for fx in self.pre_run_fxs:
fx(board)

def post_run_callback(self,
board: NxBoard = None,
_var_id_to_var_model_map: dict = None
) -> None:
for fx in self.post_run_fxs:
fx(board)

0 comments on commit 4d455b3

Please sign in to comment.