Fix STDP k-value error for edge case #2443

Merged · 16 commits · Apr 5, 2023
5 changes: 3 additions & 2 deletions nestkernel/archiving_node.cpp
@@ -191,12 +191,13 @@ nest::ArchivingNode::set_spiketime( Time const& t_sp, double offset )
// - its access counter indicates it has been read out by all connected
// STDP synapses, and
// - there is another, later spike, that is strictly more than
// (max_delay_ + eps) away from the new spike (at t_sp_ms)
// (min_global_delay + max_local_delay + eps) away from the new spike (at t_sp_ms)
while ( history_.size() > 1 )
{
const double next_t_sp = history_[ 1 ].t_;
if ( history_.front().access_counter_ >= n_incoming_
and t_sp_ms - next_t_sp > max_delay_ + kernel().connection_manager.get_stdp_eps() )
and t_sp_ms - next_t_sp
> max_delay_ + kernel().connection_manager.get_min_delay() + kernel().connection_manager.get_stdp_eps() )
{
history_.pop_front();
}
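For orientation, the corrected pruning rule can be sketched in Python roughly as follows. The helper name `prune_history` and the `HistEntry` tuple are illustrative stand-ins for the kernel's `history_` deque and its entries, and the early `break` is assumed (the corresponding lines are folded in the diff); only the comparison against `max_delay + min_delay + stdp_eps` reflects the change in this hunk.

```python
from collections import deque, namedtuple

# Illustrative stand-in for an entry of ArchivingNode::history_.
HistEntry = namedtuple("HistEntry", ["t", "access_counter"])


def prune_history(history, n_incoming, t_sp_ms, max_delay, min_delay, stdp_eps):
    """Drop archived post-synaptic spikes that no STDP synapse can still need.

    An entry is removed only if every incoming STDP synapse has read it and the
    next spike in the history lies strictly more than max_delay + min_delay + eps
    before the new spike at t_sp_ms -- the corrected bound from this hunk
    (the old condition used max_delay alone).
    """
    while len(history) > 1:
        next_t_sp = history[1].t
        if (history[0].access_counter >= n_incoming
                and t_sp_ms - next_t_sp > max_delay + min_delay + stdp_eps):
            history.popleft()
        else:
            break
    return history


# Example: with max_delay=1.0 and min_delay=0.1, an entry lying 1.05 ms before the
# next spike is now kept, whereas the old bound (max_delay alone) would drop it.
history = deque([HistEntry(t=10.0, access_counter=3), HistEntry(t=11.05, access_counter=0)])
prune_history(history, n_incoming=3, t_sp_ms=12.1, max_delay=1.0, min_delay=0.1, stdp_eps=1e-6)
assert len(history) == 2
```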
156 changes: 80 additions & 76 deletions testsuite/pytests/test_stdp_synapse.py
@@ -27,6 +27,7 @@
try:
import matplotlib as mpl
import matplotlib.pyplot as plt

DEBUG_PLOTS = True
except Exception:
DEBUG_PLOTS = False
@@ -42,11 +43,11 @@ class TestSTDPSynapse:
"""

def init_params(self):
self.resolution = 0.1 # [ms]
self.simulation_duration = 1E3 # [ms]
self.resolution = 0.1 # [ms]
self.simulation_duration = 1E3 # [ms]
self.synapse_model = "stdp_synapse"
self.presynaptic_firing_rate = 20. # [ms^-1]
self.postsynaptic_firing_rate = 20. # [ms^-1]
self.presynaptic_firing_rate = 20. # [ms^-1]
self.postsynaptic_firing_rate = 20. # [ms^-1]
self.tau_pre = 16.8
self.tau_post = 33.7
self.init_weight = .5
@@ -65,44 +66,37 @@ def init_params(self):
}
self.neuron_parameters = {
"tau_minus": self.tau_post,
"t_ref": 1.0
}

# While the random sequences, fairly long, would supposedly
# reveal small differences in the weight change between NEST
# and ours, some low-probability events (say, coinciding
# spikes) can well not have occurred. To generate and
# test every possible combination of pre/post order, we
# append some hardcoded spike sequences:
# pre: 1 5 6 7 9 11 12 13
# post: 2 3 4 8 9 10 12
self.hardcoded_pre_times = np.array([1, 5, 6, 7, 9, 11, 12, 13], dtype=float)
self.hardcoded_post_times = np.array([2, 3, 4, 8, 9, 10, 12], dtype=float)
self.hardcoded_trains_length = 2. + max(np.amax(self.hardcoded_pre_times), np.amax(self.hardcoded_post_times))
# append some hardcoded spike sequences
self.hardcoded_pre_times = np.array(
[1, 5, 6, 7, 9, 11, 12, 13, 14.5, 16.1], dtype=float)
self.hardcoded_post_times = np.array(
[2, 3, 4, 8, 9, 10, 12, 13.2, 15.1, 16.4], dtype=float)
self.hardcoded_trains_length = 2. + max(
np.amax(self.hardcoded_pre_times),
np.amax(self.hardcoded_post_times))

def do_nest_simulation_and_compare_to_reproduced_weight(self, fname_snip):
pre_spikes, post_spikes, t_weight_by_nest, weight_by_nest = self.do_the_nest_simulation()
if DEBUG_PLOTS:
self.plot_weight_evolution(pre_spikes, post_spikes,
t_weight_by_nest,
weight_by_nest,
fname_snip=fname_snip,
self.plot_weight_evolution(pre_spikes, post_spikes, t_weight_by_nest, weight_by_nest, fname_snip=fname_snip,
title_snip=self.nest_neuron_model + " (NEST)")

t_weight_reproduced_independently, weight_reproduced_independently = self.reproduce_weight_drift(
pre_spikes, post_spikes,
self.init_weight,
fname_snip=fname_snip)
t_weight_reproduced_independently, weight_reproduced_independently = \
self.reproduce_weight_drift(pre_spikes, post_spikes, self.init_weight, fname_snip=fname_snip)

# ``weight_by_nest`` contains only weight values at pre spike times, ``weight_reproduced_independently``
# contains the weight at pre *and* post times: check that weights are equal only for pre spike times
# contains the weight at pre *and* post times: check that weights are equal for pre spike times
assert len(weight_by_nest) > 0
for idx_pre_spike_nest, t_pre_spike_nest in enumerate(t_weight_by_nest):
idx_pre_spike_reproduced_independently = \
np.argmin((t_pre_spike_nest - t_weight_reproduced_independently)**2)
np.testing.assert_allclose(t_pre_spike_nest,
t_weight_reproduced_independently[idx_pre_spike_reproduced_independently])
np.testing.assert_allclose(weight_by_nest[idx_pre_spike_nest],
weight_reproduced_independently[idx_pre_spike_reproduced_independently])
np.testing.assert_allclose(t_weight_by_nest, t_weight_reproduced_independently)
np.testing.assert_allclose(weight_by_nest, weight_reproduced_independently)
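A brief aside on the tightened check above, not taken from the diff itself: because `reproduce_weight_drift` now returns times and weights sampled only at pre-synaptic spike times, the NEST trace and the reference trace line up index by index, so the per-spike `argmin` matching can be replaced by a single array-wise comparison. The values below are made up purely to illustrate the shape of the check.

```python
import numpy as np

# Hypothetical data: both traces are sampled at the same pre-spike times.
t_weight_by_nest = np.array([101.3, 148.9, 202.4])
weight_by_nest = np.array([0.502, 0.517, 0.493])
t_weight_reference = np.array([101.3, 148.9, 202.4])
weight_reference = np.array([0.502, 0.517, 0.493])

# Element-wise comparison of times first, then weights, as in the updated test.
np.testing.assert_allclose(t_weight_by_nest, t_weight_reference)
np.testing.assert_allclose(weight_by_nest, weight_reference)
```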

def do_the_nest_simulation(self):
"""
@@ -134,21 +128,25 @@ def do_the_nest_simulation(self):
spike_senders = nest.Create(
"spike_generator",
2,
params=({"spike_times": self.hardcoded_pre_times
+ self.simulation_duration - self.hardcoded_trains_length},
{"spike_times": self.hardcoded_post_times
+ self.simulation_duration - self.hardcoded_trains_length})
params=({"spike_times": self.hardcoded_pre_times +
self.simulation_duration - self.hardcoded_trains_length},
{"spike_times": self.hardcoded_post_times +
self.simulation_duration - self.hardcoded_trains_length})
)
pre_spike_generator = spike_senders[0]
post_spike_generator = spike_senders[1]

# The recorder is to save the randomly generated spike trains.
spike_recorder = nest.Create("spike_recorder")

nest.Connect(presynaptic_generator + pre_spike_generator, presynaptic_neuron,
syn_spec={"synapse_model": "static_synapse", "weight": 9999.})
nest.Connect(postsynaptic_generator + post_spike_generator, postsynaptic_neuron,
syn_spec={"synapse_model": "static_synapse", "weight": 9999.})
nest.Connect(presynaptic_generator + pre_spike_generator,
presynaptic_neuron,
syn_spec={"synapse_model": "static_synapse",
"weight": 9999.})
nest.Connect(postsynaptic_generator + post_spike_generator,
postsynaptic_neuron,
syn_spec={"synapse_model": "static_synapse",
"weight": 9999.})
nest.Connect(presynaptic_neuron + postsynaptic_neuron, spike_recorder,
syn_spec={"synapse_model": "static_synapse"})
# The synapse of interest itself
@@ -169,19 +167,20 @@

def reproduce_weight_drift(self, pre_spikes, post_spikes, initial_weight, fname_snip=""):
"""Independent, self-contained model of STDP"""

def facilitate(w, Kpre, Wmax_=1.):
norm_w = (w / self.synapse_parameters["Wmax"]) + (
self.synapse_parameters["lambda"] * pow(
1 - (w / self.synapse_parameters["Wmax"]), self.synapse_parameters["mu_plus"]) * Kpre)
self.synapse_parameters["lambda"] * pow(1 - (w / self.synapse_parameters["Wmax"]),
self.synapse_parameters["mu_plus"]) * Kpre)
if norm_w < 1.0:
return norm_w * self.synapse_parameters["Wmax"]
else:
return self.synapse_parameters["Wmax"]

def depress(w, Kpost):
norm_w = (w / self.synapse_parameters["Wmax"]) - (
self.synapse_parameters["alpha"] * self.synapse_parameters["lambda"] * pow(
w / self.synapse_parameters["Wmax"], self.synapse_parameters["mu_minus"]) * Kpost)
norm_w = (w / self.synapse_parameters["Wmax"]) - \
(self.synapse_parameters["alpha"] * self.synapse_parameters["lambda"] *
pow(w / self.synapse_parameters["Wmax"], self.synapse_parameters["mu_minus"]) * Kpost)
if norm_w > 0.0:
return norm_w * self.synapse_parameters["Wmax"]
else:
@@ -209,42 +208,47 @@ def Kpost_at_time(t, spikes, inclusive=True):
Kpost *= exp(-(t - t_curr) / self.tau_post)
return Kpost

eps = 1e-6
t = 0.
idx_next_pre_spike = 0
idx_next_post_spike = 0
t_last_pre_spike = -1
t_last_post_spike = -1
Kpre = 0.
weight = initial_weight

t_log = []
w_log = []
Kpre_log = []

# logging
t_log.append(t)
w_log.append(weight)
Kpre_log.append(Kpre)
t_log = list()
w_log = list()
Kpre_log = list()
pre_spike_times = list()

post_spikes_delayed = post_spikes + self.dendritic_delay

while t < self.simulation_duration:
idx_next_pre_spike = -1
if np.where((pre_spikes - t) > 0)[0].size > 0:
idx_next_pre_spike = np.where((pre_spikes - t) > 0)[0][0]
if idx_next_pre_spike >= pre_spikes.size:
t_next_pre_spike = -1
else:
t_next_pre_spike = pre_spikes[idx_next_pre_spike]

idx_next_post_spike = -1
if np.where((post_spikes_delayed - t) > 0)[0].size > 0:
idx_next_post_spike = np.where((post_spikes_delayed - t) > 0)[0][0]
if idx_next_post_spike >= post_spikes.size:
t_next_post_spike = -1
else:
t_next_post_spike = post_spikes_delayed[idx_next_post_spike]

if idx_next_pre_spike >= 0 and idx_next_post_spike >= 0 and t_next_post_spike < t_next_pre_spike:
handle_post_spike = True
if t_next_post_spike >= 0 and (t_next_post_spike + eps < t_next_pre_spike or t_next_pre_spike < 0):
handle_pre_spike = False
elif idx_next_pre_spike >= 0 and idx_next_post_spike >= 0 and t_next_post_spike > t_next_pre_spike:
handle_post_spike = False
handle_post_spike = True
idx_next_post_spike += 1
elif t_next_pre_spike >= 0 and (t_next_post_spike > t_next_pre_spike + eps or t_next_post_spike < 0):
handle_pre_spike = True
handle_post_spike = False
idx_next_pre_spike += 1
else:
# simultaneous spikes (both true) or no more spikes to process (both false)
handle_post_spike = idx_next_post_spike >= 0
handle_pre_spike = idx_next_pre_spike >= 0
handle_pre_spike = t_next_pre_spike >= 0
handle_post_spike = t_next_post_spike >= 0
idx_next_pre_spike += 1
idx_next_post_spike += 1

# integrate to min(t_next_pre_spike, t_next_post_spike)
t_next = t
@@ -257,40 +261,40 @@ def Kpost_at_time(t, spikes, inclusive=True):
# no more spikes to process
t_next = self.simulation_duration

'''# max timestep
t_next_ = min(t_next, t + 1E-3)
if t_next != t_next_:
t_next = t_next_
handle_pre_spike = False
handle_post_spike = False'''

h = t_next - t
Kpre *= exp(-h / self.tau_pre)
t = t_next

if handle_post_spike:
# Kpost += 1. <-- not necessary, will call Kpost_at_time(t) later to compute Kpost for any value t
weight = facilitate(weight, Kpre)
if not handle_pre_spike or abs(t_next_post_spike - t_last_post_spike) > eps:
if abs(t_next_post_spike - t_last_pre_spike) > eps:
weight = facilitate(weight, Kpre)

if handle_pre_spike:
Kpre += 1.
_Kpost = Kpost_at_time(t - self.dendritic_delay, post_spikes, inclusive=False)
weight = depress(weight, _Kpost)
if not handle_post_spike or abs(t_next_pre_spike - t_last_pre_spike) > eps:
if abs(t_next_pre_spike - t_last_post_spike) > eps:
_Kpost = Kpost_at_time(t - self.dendritic_delay, post_spikes, inclusive=False)
weight = depress(weight, _Kpost)
t_last_pre_spike = t_next_pre_spike
pre_spike_times.append(t)

if handle_post_spike:
t_last_post_spike = t_next_post_spike

# logging
t_log.append(t)
w_log.append(weight)
Kpre_log.append(Kpre)
w_log.append(weight)
t_log.append(t)

Kpost_log = [Kpost_at_time(t - self.dendritic_delay, post_spikes) for t in t_log]
if DEBUG_PLOTS:
self.plot_weight_evolution(pre_spikes, post_spikes, t_log, w_log, Kpre_log, Kpost_log,
fname_snip=fname_snip + "_ref", title_snip="Reference")

return t_log, w_log
return pre_spike_times, [w_log[i] for i, t in enumerate(t_log) if t in pre_spike_times]

def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=None, Kpost_log=None,
fname_snip="", title_snip=""):
def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=None, Kpost_log=None, fname_snip="",
title_snip=""):
fig, ax = plt.subplots(nrows=3)

n_spikes = len(pre_spikes)
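To make the control flow of the rewritten reference loop easier to follow, here is a distilled sketch of just the event-ordering decision. The function name `next_event` and the 1 ms dendritic delay in the example are illustrative assumptions, not the test code itself, but the branches follow the same eps-tolerant rules as the loop above: the earlier of the next pre and next delayed post spike is handled first, and near-coincident pre/post spikes are handled together.

```python
import numpy as np


def next_event(t, pre_spikes, post_spikes_delayed, eps=1e-6):
    """Return (t_next, handle_pre, handle_post) for the next STDP event after t.

    Mirrors the ordering logic of the rewritten reference loop: spikes closer
    together than eps count as coincident; (None, False, False) means no spikes
    remain to be processed.
    """
    future_pre = pre_spikes[pre_spikes > t]
    future_post = post_spikes_delayed[post_spikes_delayed > t]
    t_pre = future_pre[0] if future_pre.size else None
    t_post = future_post[0] if future_post.size else None

    if t_pre is None and t_post is None:
        return None, False, False
    if t_post is not None and (t_pre is None or t_post + eps < t_pre):
        return t_post, False, True
    if t_pre is not None and (t_post is None or t_pre + eps < t_post):
        return t_pre, True, False
    # pre and delayed post spike coincide within eps: handle both
    return t_pre, True, True


# Example with the hard-coded trains used above; the 1 ms dendritic delay is
# assumed purely for illustration.
pre = np.array([1., 5., 6., 7., 9., 11., 12., 13., 14.5, 16.1])
post_delayed = np.array([2., 3., 4., 8., 9., 10., 12., 13.2, 15.1, 16.4]) + 1.0
print(next_event(0.0, pre, post_delayed))   # (1.0, True, False): first pre spike precedes any delayed post spike
print(next_event(12.5, pre, post_delayed))  # (13.0, True, True): pre at 13.0 and delayed post at 12.0 + 1.0 coincide
```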
5 changes: 3 additions & 2 deletions testsuite/regressiontests/issue-77.sli
@@ -102,9 +102,10 @@ M_ERROR setverbosity
/receptor_type 1 >>
/iaf_psc_exp_multisynapse << /params << /tau_syn [ 1.0 ] >>
/receptor_type 1 >>
/aeif_cond_alpha_multisynapse << /params << /tau_syn [ 2.0 ] >>
/aeif_cond_alpha_multisynapse << /params << /E_rev [ -20.0 ]
/tau_syn [ 2.0 ] >>
/receptor_type 1 >>
/aeif_cond_beta_multisynapse << /params << /E_rev [ 0.0 ]
/aeif_cond_beta_multisynapse << /params << /E_rev [ -20.0 ]
/tau_rise [ 1.0 ]
/tau_decay [ 2.0 ] >>
/receptor_type 1 >>
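For readers more at home in PyNEST than SLI, the receptor parameterisation exercised by this regression test looks roughly like the sketch below. The parameter names `E_rev`, `tau_syn`, `tau_rise`, `tau_decay` and `receptor_type` 1 come straight from the SLI test above; the spike time, weight, delay and simulation length are arbitrary, and the NEST 3.x Python API is assumed.

```python
import nest

nest.ResetKernel()

# aeif_cond_alpha_multisynapse is conductance based, so each receptor port needs
# a reversal potential E_rev in addition to tau_syn (the point of this change).
nrn_alpha = nest.Create("aeif_cond_alpha_multisynapse",
                        params={"E_rev": [-20.0], "tau_syn": [2.0]})

# aeif_cond_beta_multisynapse: same idea, with separate rise and decay time constants.
nrn_beta = nest.Create("aeif_cond_beta_multisynapse",
                       params={"E_rev": [-20.0], "tau_rise": [1.0], "tau_decay": [2.0]})

# Drive both neurons through receptor port 1, as the SLI test does.
sg = nest.Create("spike_generator", params={"spike_times": [10.0]})
for nrn in (nrn_alpha, nrn_beta):
    nest.Connect(sg, nrn, syn_spec={"receptor_type": 1, "weight": 1.0, "delay": 1.0})

nest.Simulate(100.0)
```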