Skip to content

Commit

Permalink
Merge pull request #2443 from JanVogelsang/stdp-k-value-error
Browse files Browse the repository at this point in the history
Fix STDP k-value error for edge case
  • Loading branch information
abigailm authored Apr 5, 2023
2 parents f835afd + d77aa83 commit b3ba799
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 80 deletions.
5 changes: 3 additions & 2 deletions nestkernel/archiving_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,13 @@ nest::ArchivingNode::set_spiketime( Time const& t_sp, double offset )
// - its access counter indicates it has been read out by all connected
// STDP synapses, and
// - there is another, later spike, that is strictly more than
// (max_delay_ + eps) away from the new spike (at t_sp_ms)
// (min_global_delay + max_local_delay + eps) away from the new spike (at t_sp_ms)
while ( history_.size() > 1 )
{
const double next_t_sp = history_[ 1 ].t_;
if ( history_.front().access_counter_ >= n_incoming_
and t_sp_ms - next_t_sp > max_delay_ + kernel().connection_manager.get_stdp_eps() )
and t_sp_ms - next_t_sp
> max_delay_ + kernel().connection_manager.get_min_delay() + kernel().connection_manager.get_stdp_eps() )
{
history_.pop_front();
}
Expand Down
156 changes: 80 additions & 76 deletions testsuite/pytests/test_stdp_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
try:
import matplotlib as mpl
import matplotlib.pyplot as plt

DEBUG_PLOTS = True
except Exception:
DEBUG_PLOTS = False
Expand All @@ -42,11 +43,11 @@ class TestSTDPSynapse:
"""

def init_params(self):
self.resolution = 0.1 # [ms]
self.simulation_duration = 1E3 # [ms]
self.resolution = 0.1 # [ms]
self.simulation_duration = 1E3 # [ms]
self.synapse_model = "stdp_synapse"
self.presynaptic_firing_rate = 20. # [ms^-1]
self.postsynaptic_firing_rate = 20. # [ms^-1]
self.presynaptic_firing_rate = 20. # [ms^-1]
self.postsynaptic_firing_rate = 20. # [ms^-1]
self.tau_pre = 16.8
self.tau_post = 33.7
self.init_weight = .5
Expand All @@ -65,44 +66,37 @@ def init_params(self):
}
self.neuron_parameters = {
"tau_minus": self.tau_post,
"t_ref": 1.0
}

# While the random sequences, fairly long, would supposedly
# reveal small differences in the weight change between NEST
# and ours, some low-probability events (say, coinciding
# spikes) can well not have occurred. To generate and
# test every possible combination of pre/post order, we
# append some hardcoded spike sequences:
# pre: 1 5 6 7 9 11 12 13
# post: 2 3 4 8 9 10 12
self.hardcoded_pre_times = np.array([1, 5, 6, 7, 9, 11, 12, 13], dtype=float)
self.hardcoded_post_times = np.array([2, 3, 4, 8, 9, 10, 12], dtype=float)
self.hardcoded_trains_length = 2. + max(np.amax(self.hardcoded_pre_times), np.amax(self.hardcoded_post_times))
# append some hardcoded spike sequences
self.hardcoded_pre_times = np.array(
[1, 5, 6, 7, 9, 11, 12, 13, 14.5, 16.1], dtype=float)
self.hardcoded_post_times = np.array(
[2, 3, 4, 8, 9, 10, 12, 13.2, 15.1, 16.4], dtype=float)
self.hardcoded_trains_length = 2. + max(
np.amax(self.hardcoded_pre_times),
np.amax(self.hardcoded_post_times))

def do_nest_simulation_and_compare_to_reproduced_weight(self, fname_snip):
pre_spikes, post_spikes, t_weight_by_nest, weight_by_nest = self.do_the_nest_simulation()
if DEBUG_PLOTS:
self.plot_weight_evolution(pre_spikes, post_spikes,
t_weight_by_nest,
weight_by_nest,
fname_snip=fname_snip,
self.plot_weight_evolution(pre_spikes, post_spikes, t_weight_by_nest, weight_by_nest, fname_snip=fname_snip,
title_snip=self.nest_neuron_model + " (NEST)")

t_weight_reproduced_independently, weight_reproduced_independently = self.reproduce_weight_drift(
pre_spikes, post_spikes,
self.init_weight,
fname_snip=fname_snip)
t_weight_reproduced_independently, weight_reproduced_independently = \
self.reproduce_weight_drift(pre_spikes, post_spikes, self.init_weight, fname_snip=fname_snip)

# ``weight_by_nest`` contains only weight values at pre spike times, ``weight_reproduced_independently``
# contains the weight at pre *and* post times: check that weights are equal only for pre spike times
# contains the weight at pre *and* post times: check that weights are equal for pre spike times
assert len(weight_by_nest) > 0
for idx_pre_spike_nest, t_pre_spike_nest in enumerate(t_weight_by_nest):
idx_pre_spike_reproduced_independently = \
np.argmin((t_pre_spike_nest - t_weight_reproduced_independently)**2)
np.testing.assert_allclose(t_pre_spike_nest,
t_weight_reproduced_independently[idx_pre_spike_reproduced_independently])
np.testing.assert_allclose(weight_by_nest[idx_pre_spike_nest],
weight_reproduced_independently[idx_pre_spike_reproduced_independently])
np.testing.assert_allclose(t_weight_by_nest, t_weight_reproduced_independently)
np.testing.assert_allclose(weight_by_nest, weight_reproduced_independently)

def do_the_nest_simulation(self):
"""
Expand Down Expand Up @@ -134,21 +128,25 @@ def do_the_nest_simulation(self):
spike_senders = nest.Create(
"spike_generator",
2,
params=({"spike_times": self.hardcoded_pre_times
+ self.simulation_duration - self.hardcoded_trains_length},
{"spike_times": self.hardcoded_post_times
+ self.simulation_duration - self.hardcoded_trains_length})
params=({"spike_times": self.hardcoded_pre_times +
self.simulation_duration - self.hardcoded_trains_length},
{"spike_times": self.hardcoded_post_times +
self.simulation_duration - self.hardcoded_trains_length})
)
pre_spike_generator = spike_senders[0]
post_spike_generator = spike_senders[1]

# The recorder is to save the randomly generated spike trains.
spike_recorder = nest.Create("spike_recorder")

nest.Connect(presynaptic_generator + pre_spike_generator, presynaptic_neuron,
syn_spec={"synapse_model": "static_synapse", "weight": 9999.})
nest.Connect(postsynaptic_generator + post_spike_generator, postsynaptic_neuron,
syn_spec={"synapse_model": "static_synapse", "weight": 9999.})
nest.Connect(presynaptic_generator + pre_spike_generator,
presynaptic_neuron,
syn_spec={"synapse_model": "static_synapse",
"weight": 9999.})
nest.Connect(postsynaptic_generator + post_spike_generator,
postsynaptic_neuron,
syn_spec={"synapse_model": "static_synapse",
"weight": 9999.})
nest.Connect(presynaptic_neuron + postsynaptic_neuron, spike_recorder,
syn_spec={"synapse_model": "static_synapse"})
# The synapse of interest itself
Expand All @@ -169,19 +167,20 @@ def do_the_nest_simulation(self):

def reproduce_weight_drift(self, pre_spikes, post_spikes, initial_weight, fname_snip=""):
"""Independent, self-contained model of STDP"""

def facilitate(w, Kpre, Wmax_=1.):
norm_w = (w / self.synapse_parameters["Wmax"]) + (
self.synapse_parameters["lambda"] * pow(
1 - (w / self.synapse_parameters["Wmax"]), self.synapse_parameters["mu_plus"]) * Kpre)
self.synapse_parameters["lambda"] * pow(1 - (w / self.synapse_parameters["Wmax"]),
self.synapse_parameters["mu_plus"]) * Kpre)
if norm_w < 1.0:
return norm_w * self.synapse_parameters["Wmax"]
else:
return self.synapse_parameters["Wmax"]

def depress(w, Kpost):
norm_w = (w / self.synapse_parameters["Wmax"]) - (
self.synapse_parameters["alpha"] * self.synapse_parameters["lambda"] * pow(
w / self.synapse_parameters["Wmax"], self.synapse_parameters["mu_minus"]) * Kpost)
norm_w = (w / self.synapse_parameters["Wmax"]) - \
(self.synapse_parameters["alpha"] * self.synapse_parameters["lambda"] *
pow(w / self.synapse_parameters["Wmax"], self.synapse_parameters["mu_minus"]) * Kpost)
if norm_w > 0.0:
return norm_w * self.synapse_parameters["Wmax"]
else:
Expand Down Expand Up @@ -209,42 +208,47 @@ def Kpost_at_time(t, spikes, inclusive=True):
Kpost *= exp(-(t - t_curr) / self.tau_post)
return Kpost

eps = 1e-6
t = 0.
idx_next_pre_spike = 0
idx_next_post_spike = 0
t_last_pre_spike = -1
t_last_post_spike = -1
Kpre = 0.
weight = initial_weight

t_log = []
w_log = []
Kpre_log = []

# logging
t_log.append(t)
w_log.append(weight)
Kpre_log.append(Kpre)
t_log = list()
w_log = list()
Kpre_log = list()
pre_spike_times = list()

post_spikes_delayed = post_spikes + self.dendritic_delay

while t < self.simulation_duration:
idx_next_pre_spike = -1
if np.where((pre_spikes - t) > 0)[0].size > 0:
idx_next_pre_spike = np.where((pre_spikes - t) > 0)[0][0]
if idx_next_pre_spike >= pre_spikes.size:
t_next_pre_spike = -1
else:
t_next_pre_spike = pre_spikes[idx_next_pre_spike]

idx_next_post_spike = -1
if np.where((post_spikes_delayed - t) > 0)[0].size > 0:
idx_next_post_spike = np.where((post_spikes_delayed - t) > 0)[0][0]
if idx_next_post_spike >= post_spikes.size:
t_next_post_spike = -1
else:
t_next_post_spike = post_spikes_delayed[idx_next_post_spike]

if idx_next_pre_spike >= 0 and idx_next_post_spike >= 0 and t_next_post_spike < t_next_pre_spike:
handle_post_spike = True
if t_next_post_spike >= 0 and (t_next_post_spike + eps < t_next_pre_spike or t_next_pre_spike < 0):
handle_pre_spike = False
elif idx_next_pre_spike >= 0 and idx_next_post_spike >= 0 and t_next_post_spike > t_next_pre_spike:
handle_post_spike = False
handle_post_spike = True
idx_next_post_spike += 1
elif t_next_pre_spike >= 0 and (t_next_post_spike > t_next_pre_spike + eps or t_next_post_spike < 0):
handle_pre_spike = True
handle_post_spike = False
idx_next_pre_spike += 1
else:
# simultaneous spikes (both true) or no more spikes to process (both false)
handle_post_spike = idx_next_post_spike >= 0
handle_pre_spike = idx_next_pre_spike >= 0
handle_pre_spike = t_next_pre_spike >= 0
handle_post_spike = t_next_post_spike >= 0
idx_next_pre_spike += 1
idx_next_post_spike += 1

# integrate to min(t_next_pre_spike, t_next_post_spike)
t_next = t
Expand All @@ -257,40 +261,40 @@ def Kpost_at_time(t, spikes, inclusive=True):
# no more spikes to process
t_next = self.simulation_duration

'''# max timestep
t_next_ = min(t_next, t + 1E-3)
if t_next != t_next_:
t_next = t_next_
handle_pre_spike = False
handle_post_spike = False'''

h = t_next - t
Kpre *= exp(-h / self.tau_pre)
t = t_next

if handle_post_spike:
# Kpost += 1. <-- not necessary, will call Kpost_at_time(t) later to compute Kpost for any value t
weight = facilitate(weight, Kpre)
if not handle_pre_spike or abs(t_next_post_spike - t_last_post_spike) > eps:
if abs(t_next_post_spike - t_last_pre_spike) > eps:
weight = facilitate(weight, Kpre)

if handle_pre_spike:
Kpre += 1.
_Kpost = Kpost_at_time(t - self.dendritic_delay, post_spikes, inclusive=False)
weight = depress(weight, _Kpost)
if not handle_post_spike or abs(t_next_pre_spike - t_last_pre_spike) > eps:
if abs(t_next_pre_spike - t_last_post_spike) > eps:
_Kpost = Kpost_at_time(t - self.dendritic_delay, post_spikes, inclusive=False)
weight = depress(weight, _Kpost)
t_last_pre_spike = t_next_pre_spike
pre_spike_times.append(t)

if handle_post_spike:
t_last_post_spike = t_next_post_spike

# logging
t_log.append(t)
w_log.append(weight)
Kpre_log.append(Kpre)
w_log.append(weight)
t_log.append(t)

Kpost_log = [Kpost_at_time(t - self.dendritic_delay, post_spikes) for t in t_log]
if DEBUG_PLOTS:
self.plot_weight_evolution(pre_spikes, post_spikes, t_log, w_log, Kpre_log, Kpost_log,
fname_snip=fname_snip + "_ref", title_snip="Reference")

return t_log, w_log
return pre_spike_times, [w_log[i] for i, t in enumerate(t_log) if t in pre_spike_times]

def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=None, Kpost_log=None,
fname_snip="", title_snip=""):
def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=None, Kpost_log=None, fname_snip="",
title_snip=""):
fig, ax = plt.subplots(nrows=3)

n_spikes = len(pre_spikes)
Expand Down
5 changes: 3 additions & 2 deletions testsuite/regressiontests/issue-77.sli
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@ M_ERROR setverbosity
/receptor_type 1 >>
/iaf_psc_exp_multisynapse << /params << /tau_syn [ 1.0 ] >>
/receptor_type 1 >>
/aeif_cond_alpha_multisynapse << /params << /tau_syn [ 2.0 ] >>
/aeif_cond_alpha_multisynapse << /params << /E_rev [ -20.0 ]
/tau_syn [ 2.0 ] >>
/receptor_type 1 >>
/aeif_cond_beta_multisynapse << /params << /E_rev [ 0.0 ]
/aeif_cond_beta_multisynapse << /params << /E_rev [ -20.0 ]
/tau_rise [ 1.0 ]
/tau_decay [ 2.0 ] >>
/receptor_type 1 >>
Expand Down

0 comments on commit b3ba799

Please sign in to comment.