From dcf105f22a9c84f3f4a6f8309f432db089a53a5c Mon Sep 17 00:00:00 2001 From: Philipp Date: Fri, 21 Oct 2022 08:32:25 -0700 Subject: [PATCH 1/2] fix the stdp tests. minor update to structure --- src/lava/magma/core/model/py/connection.py | 101 ++++++++------------- 1 file changed, 40 insertions(+), 61 deletions(-) diff --git a/src/lava/magma/core/model/py/connection.py b/src/lava/magma/core/model/py/connection.py index 7ef2aa03f..e032608e0 100644 --- a/src/lava/magma/core/model/py/connection.py +++ b/src/lava/magma/core/model/py/connection.py @@ -60,23 +60,22 @@ def __init__(self, proc_params: dict) -> None: self.sign_mode = proc_params.get("sign_mode", SignMode.MIXED) - if self._learning_rule is not None: - # store shapes that useful throughout the lifetime of this PM - self._store_shapes() - # store impulses and taus in ndarrays with the right shapes - self._store_impulses_and_taus() - - # store active traces per dependency from learning_rule in ndarrays - # with the right shapes - self._build_active_traces_per_dependency() - # store active traces from learning_rule in ndarrays - # with the right shapes - self._build_active_traces() - # generate LearningRuleApplierBitApprox from ProductSeries - self._build_learning_rule_appliers() - - # initialize TraceRandoms and ConnVarRandom - self._init_randoms() + # store shapes that useful throughout the lifetime of this PM + self._store_shapes() + # store impulses and taus in ndarrays with the right shapes + self._store_impulses_and_taus() + + # store active traces per dependency from learning_rule in ndarrays + # with the right shapes + self._build_active_traces_per_dependency() + # store active traces from learning_rule in ndarrays + # with the right shapes + self._build_active_traces() + # generate LearningRuleApplierBitApprox from ProductSeries + self._build_learning_rule_appliers() + + # initialize TraceRandoms and ConnVarRandom + self._init_randoms() def _store_shapes(self) -> None: """Build and store several shapes that are needed in several @@ -287,6 +286,26 @@ def _within_epoch_time_step(self) -> int: return within_epoch_ts + def run_spk(self, s_in) -> None: + """ + Overrides the Connection Model run_spk function to + receive and update y2 and y3 traces. + """ + self._record_pre_spike_times(s_in) + + s_in_bap = self.s_in_bap.recv().astype(bool) + y2 = self.s_in_y2.recv() + y3 = self.s_in_y3.recv() + + self._record_post_spike_times(s_in_bap) + + y_traces = self._y_traces + y_traces[1, :] = y2 + y_traces[2, :] = y3 + self._set_y_traces(y_traces) + + self._update_trace_randoms() + def lrn_guard(self) -> bool: return self.time_step % self._learning_rule.t_epoch == 0 @@ -296,6 +315,10 @@ def run_lrn(self) -> None: self._update_traces() self._reset_dependencies_and_spike_times() + @abstractmethod + def _record_pre_spike_times(self, s_in: np.ndarray) -> None: + pass + @abstractmethod def _record_post_spike_times(self, s_in_bap: np.ndarray) -> None: pass @@ -390,13 +413,6 @@ class PlasticConnectionModelBitApproximate(PlasticConnection): tag_2: np.ndarray = LavaPyType(np.ndarray, int, precision=6) tag_1: np.ndarray = LavaPyType(np.ndarray, int, precision=8) - # def __init__(self, proc_params: dict) -> None: - # - # # Flag to determine whether weights have already been scaled. - # self.weights_set = False - # - # super().__init__(proc_params) - def _store_impulses_and_taus(self) -> None: """Build and store integer ndarrays representing x and y impulses and taus.""" @@ -1034,24 +1050,6 @@ def _update_traces(self) -> None: ) ) - def run_spk(self, s_in) -> None: - """ - Overrides the Connection Model run_spk function to - receive and update y2 and y3 traces. - """ - self._record_pre_spike_times(s_in) - - s_in_bap = self.s_in_bap.recv().astype(bool) - y2 = self.s_in_y2.recv() - y3 = self.s_in_y3.recv() - - self._record_post_spike_times(s_in_bap) - - y_traces = self._y_traces - y_traces[1, :] = y2 - y_traces[2, :] = y3 - self._set_y_traces(y_traces) - class PlasticConnectionModelFloat(PlasticConnection): """Floating-point implementation of the Connection Process @@ -1515,22 +1513,3 @@ def _update_traces(self) -> None: self._y_taus[:, np.newaxis], ) ) - - def run_spk(self, s_in) -> None: - """ - TODO: Change this - Overrides the Connection Model run_spk function to - receive and update y2 and y3 traces. - """ - self._record_pre_spike_times(s_in) - - s_in_bap = self.s_in_bap.recv().astype(bool) - y2 = self.s_in_y2.recv() - y3 = self.s_in_y3.recv() - - self._record_post_spike_times(s_in_bap) - - y_traces = self._y_traces - y_traces[1, :] = y2 - y_traces[2, :] = y3 - self._set_y_traces(y_traces) From c15319849b796506fa76938ffe203cd8aaae8e59 Mon Sep 17 00:00:00 2001 From: Philipp Date: Fri, 21 Oct 2022 08:37:35 -0700 Subject: [PATCH 2/2] minor change --- src/lava/magma/core/model/py/connection.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lava/magma/core/model/py/connection.py b/src/lava/magma/core/model/py/connection.py index e032608e0..05a9b486f 100644 --- a/src/lava/magma/core/model/py/connection.py +++ b/src/lava/magma/core/model/py/connection.py @@ -50,6 +50,9 @@ class PlasticConnection: tag_2 = None tag_1 = None + weights = None + time_step = None + def __init__(self, proc_params: dict) -> None: super().__init__(proc_params)