Skip to content

Commit

Permalink
Added iterable callback function
Browse files Browse the repository at this point in the history
Signed-off-by: bamsumit <[email protected]>
  • Loading branch information
bamsumit committed Jun 27, 2023
1 parent b0642e9 commit 27c81c2
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/lava/magma/core/callback_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from abc import ABC, abstractmethod
from typing import Iterable
try:
from nxcore.arch.base.nxboard import NxBoard
except ImportError:
Expand Down Expand Up @@ -53,3 +54,23 @@ def post_run_callback(self,
board: NxBoard = None,
var_id_to_var_model_map: dict = None):
pass


class IterableCallBack(NxSdkCallbackFx):
"""NxSDK callback function to execute iterable of function pointers
as pre and post run."""

def __init__(self,
pre_run_fxs: Iterable = [],
post_run_fxs: Iterable = []) -> None:
super().__init__()
self.pre_run_fxs = pre_run_fxs
self.post_run_fxs = post_run_fxs

def pre_run_callback(self, board: NxBoard, **_) -> None:
for fx in self.pre_run_fxs:
fx(board)

def post_run_callback(self, board: NxBoard, **_) -> None:
for fx in self.post_run_fxs:
fx(board)

0 comments on commit 27c81c2

Please sign in to comment.