Skip to content

Commit

Permalink
Merge pull request lava-nc#3 from weidel-p/bala/dev/learning_three_fa…
Browse files Browse the repository at this point in the history
…ctor

Bala/dev/learning three factor
  • Loading branch information
bala-git9 committed Oct 21, 2022
2 parents e97fbb0 + c153198 commit 9a4583d
Showing 1 changed file with 41 additions and 59 deletions.
100 changes: 41 additions & 59 deletions src/lava/magma/core/model/py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class PlasticConnection:
tag_2 = None
tag_1 = None

weights = None
time_step = None

def __init__(self, proc_params: dict) -> None:
super().__init__(proc_params)

Expand All @@ -60,23 +63,22 @@ def __init__(self, proc_params: dict) -> None:

self.sign_mode = proc_params.get("sign_mode", SignMode.MIXED)

if self._learning_rule is not None:
# store shapes that useful throughout the lifetime of this PM
self._store_shapes()
# store impulses and taus in ndarrays with the right shapes
self._store_impulses_and_taus()
# store shapes that useful throughout the lifetime of this PM
self._store_shapes()
# store impulses and taus in ndarrays with the right shapes
self._store_impulses_and_taus()

# store active traces per dependency from learning_rule in ndarrays
# with the right shapes
self._build_active_traces_per_dependency()
# store active traces from learning_rule in ndarrays
# with the right shapes
self._build_active_traces()
# generate LearningRuleApplierBitApprox from ProductSeries
self._build_learning_rule_appliers()
# store active traces per dependency from learning_rule in ndarrays
# with the right shapes
self._build_active_traces_per_dependency()
# store active traces from learning_rule in ndarrays
# with the right shapes
self._build_active_traces()
# generate LearningRuleApplierBitApprox from ProductSeries
self._build_learning_rule_appliers()

# initialize TraceRandoms and ConnVarRandom
self._init_randoms()
# initialize TraceRandoms and ConnVarRandom
self._init_randoms()

def _store_shapes(self) -> None:
"""Build and store several shapes that are needed in several
Expand Down Expand Up @@ -287,6 +289,26 @@ def _within_epoch_time_step(self) -> int:

return within_epoch_ts

def run_spk(self, s_in) -> None:
"""
Overrides the Connection Model run_spk function to
receive and update y2 and y3 traces.
"""
self._record_pre_spike_times(s_in)

s_in_bap = self.s_in_bap.recv().astype(bool)
y2 = self.s_in_y2.recv()
y3 = self.s_in_y3.recv()

self._record_post_spike_times(s_in_bap)

y_traces = self._y_traces
y_traces[1, :] = y2
y_traces[2, :] = y3
self._set_y_traces(y_traces)

self._update_trace_randoms()

def lrn_guard(self) -> bool:
return self.time_step % self._learning_rule.t_epoch == 0

Expand All @@ -296,6 +318,10 @@ def run_lrn(self) -> None:
self._update_traces()
self._reset_dependencies_and_spike_times()

@abstractmethod
def _record_pre_spike_times(self, s_in: np.ndarray) -> None:
pass

@abstractmethod
def _record_post_spike_times(self, s_in_bap: np.ndarray) -> None:
pass
Expand Down Expand Up @@ -390,13 +416,6 @@ class PlasticConnectionModelBitApproximate(PlasticConnection):
tag_2: np.ndarray = LavaPyType(np.ndarray, int, precision=6)
tag_1: np.ndarray = LavaPyType(np.ndarray, int, precision=8)

# def __init__(self, proc_params: dict) -> None:
#
# # Flag to determine whether weights have already been scaled.
# self.weights_set = False
#
# super().__init__(proc_params)

def _store_impulses_and_taus(self) -> None:
"""Build and store integer ndarrays representing x and y
impulses and taus."""
Expand Down Expand Up @@ -1034,24 +1053,6 @@ def _update_traces(self) -> None:
)
)

def run_spk(self, s_in) -> None:
"""
Overrides the Connection Model run_spk function to
receive and update y2 and y3 traces.
"""
self._record_pre_spike_times(s_in)

s_in_bap = self.s_in_bap.recv().astype(bool)
y2 = self.s_in_y2.recv()
y3 = self.s_in_y3.recv()

self._record_post_spike_times(s_in_bap)

y_traces = self._y_traces
y_traces[1, :] = y2
y_traces[2, :] = y3
self._set_y_traces(y_traces)


class PlasticConnectionModelFloat(PlasticConnection):
"""Floating-point implementation of the Connection Process
Expand Down Expand Up @@ -1515,22 +1516,3 @@ def _update_traces(self) -> None:
self._y_taus[:, np.newaxis],
)
)

def run_spk(self, s_in) -> None:
"""
TODO: Change this
Overrides the Connection Model run_spk function to
receive and update y2 and y3 traces.
"""
self._record_pre_spike_times(s_in)

s_in_bap = self.s_in_bap.recv().astype(bool)
y2 = self.s_in_y2.recv()
y3 = self.s_in_y3.recv()

self._record_post_spike_times(s_in_bap)

y_traces = self._y_traces
y_traces[1, :] = y2
y_traces[2, :] = y3
self._set_y_traces(y_traces)

0 comments on commit 9a4583d

Please sign in to comment.