
Commit 4259dd1
Merge pull request #23 from bioAI-Oslo/mikkel
Recurrent scaling in BernoulliGLM
JakobSonstebo authored Sep 17, 2023
2 parents c6b4961 + 79b0466 commit 4259dd1
Showing 9 changed files with 43 additions and 27 deletions.
5 changes: 3 additions & 2 deletions docs/introduction/introduction.rst
@@ -88,14 +88,15 @@ in the :class:`BernoulliGLM` class, using the same parameters as in the original
  .. code-block:: python

      model = BernoulliGLM(
-         alpha= 0.2, # Decay rate of the coupling strength between neurons (1/ms)
-         beta= 0.5, # Decay rate of the self-inhibition during the relative refractory period (1/ms)
+         alpha=0.2, # Decay rate of the coupling strength between neurons (1/ms)
+         beta=0.5, # Decay rate of the self-inhibition during the relative refractory period (1/ms)
          abs_ref_scale=3, # Absolute refractory period in time steps
          rel_ref_scale=7, # Relative refractory period in time steps
          abs_ref_strength=-100, # Strength of the self-inhibition during the absolute refractory period
          rel_ref_strength=-30, # Initial strength of the self-inhibition during the relative refractory period
          coupling_window=5, # Length of coupling window in time steps
          theta=5, # Threshold for firing
+         r=1, # Parameter controlling the recurrence strength
          dt=1, # Length of time step (ms)
      )
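For orientation, the updated constructor drops straight into the simulation API used elsewhere in the docs; a minimal, hedged sketch (the `network` object with its `edge_index` and `W0` attributes is an assumption borrowed from the tutorial pages, not part of this diff):

.. code-block:: python

    # Assumed usage: `network` is a connectivity graph exposing `edge_index`
    # and initial weights `W0`, as in the tutorials; 1000 steps of 1 ms each.
    spikes = model.simulate(network, n_steps=1000)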
3 changes: 3 additions & 0 deletions docs/tutorials/stimuli.rst
@@ -38,6 +38,7 @@ After we've defined the model and the stimulus, we can simulate the network and
      rel_ref_strength=-30,
      alpha=0.2,
      beta=0.5,
+     r=1
  )
# Define stimulus and add it to the model
@@ -92,6 +93,7 @@ Before we add the stimulus to the model, we'll run a simulation without it to see
      rel_ref_strength=-30,
      alpha=0.2,
      beta=0.5,
+     r=1
  )

  spikes = model.simulate(network, n_steps=n_steps)
@@ -156,6 +158,7 @@ that is close to the frequency of the stimulus.
      rel_ref_strength=-30,
      alpha=0.2,
      beta=0.5,
+     r=1
  )

  stimulus = SinStimulus(
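Downstream of these snippets the tutorial flow is unchanged by this commit; a hedged sketch of the remaining steps (no `SinStimulus` arguments are shown here because the diff elides them):

.. code-block:: python

    # Attach the sinusoidal stimulus defined above and re-run the simulation.
    model.add_stimulus(stimulus)
    spikes = model.simulate(network, n_steps=n_steps)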
3 changes: 2 additions & 1 deletion examples/large_scale_example.ipynb
@@ -550,6 +550,7 @@
" rel_ref_scale=5,\n",
" rel_ref_strength=-30,\n",
" beta=0.1,\n",
" r=1\n",
")"
]
},
@@ -842,7 +843,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.11"
},
"orig_nbformat": 4
},
25 changes: 14 additions & 11 deletions examples/simulate_with_stimulus.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions examples/working_with_stimulus.ipynb
@@ -84,6 +84,7 @@
" rel_ref_strength=-30, \n",
" alpha=0.2,\n",
" beta=0.5,\n",
" r=1,\n",
")\n",
"model.add_stimulus(stim)\n",
"\n",
@@ -233,6 +234,7 @@
" rel_ref_strength=-30, \n",
" alpha=0.2,\n",
" beta=0.5,\n",
" r=1,\n",
")"
]
},
@@ -428,6 +430,7 @@
" rel_ref_strength=-30, \n",
" alpha=0.2,\n",
" beta=0.5,\n",
" r=1,\n",
")\n",
"\n",
"\n",
28 changes: 16 additions & 12 deletions spikeometric/models/bernoulli_glm_model.py
@@ -13,12 +13,12 @@ class BernoulliGLM(BaseModel):
More formally, the model can be broken into three steps, each of which is implemented as a separate method in this class:
- #. .. math:: g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)r(\tau) + \sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)
+ #. .. math:: g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)ref(\tau) + r\sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)
#. .. math:: p_i(t+1) = \sigma(g_i(t+1) - \theta) \Delta t
#. .. math:: X_i(t+1) \sim \text{Bernoulli}(p_i(t+1))
  The first equation is implemented in the :meth:`input` method and gives us the input to the neuron :math:`i` at time :math:`t+1` as a sum of the refractory, synaptic and external inputs.
- The refractory input is calculated by convolving the spike history of the neuron itself with a refractory filter :math:`r`, the synaptic input is obtained by convolving the spike history
+ The refractory input is calculated by convolving the spike history of the neuron itself with a refractory filter :math:`ref`, the synaptic input is obtained by convolving the spike history
  of the neuron's neighbors with the coupling filter :math:`c`, weighted by the synaptic weights :math:`W_0`, and the external input is given by evaluating an external input function :math:`\mathcal{E}` at time :math:`t+1`.

  The second equation is implemented in :meth:`non_linearity`, which computes the probability that the neuron :math:`i` spikes at time :math:`t+1` by passing the thresholded input :math:`g_i(t+1) - \theta` through a sigmoid and scaling by :math:`\Delta t`.
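To make the three steps concrete, here is a minimal, self-contained sketch of one update step following the class-level equations (the tensor shapes and the explicit split between refractory and coupling terms are illustrative assumptions, not the library's internal layout):

.. code-block:: python

    import torch

    def bernoulli_glm_step(X, ref, W0c, E_next, r, theta, dt, rng=None):
        # X:      [N, T] spike history; X[:, -1] is the most recent time step
        # ref:    [T]    refractory filter, ref[tau] weights X_i(t - tau)
        # W0c:    [N, N, T] coupling weights W0 already shaped by the filter c
        # E_next: [N]    external input at time t + 1
        Xr = X.flip(1)                                  # align tau = 0 with the newest spikes
        refractory = (Xr * ref).sum(dim=1)              # sum_tau X_i(t - tau) ref(tau)
        coupling = torch.einsum("jit,jt->i", W0c, Xr)   # sum_tau sum_j (W0)_{j,i} X_j(t - tau) c(tau)
        g = refractory + r * coupling + E_next          # step 1: total input
        p = torch.sigmoid(g - theta) * dt               # step 2: spike probability
        return torch.bernoulli(p, generator=rng)        # step 3: sample spikes

In the commit itself, the scaling enters through ``self.r * self.synaptic_input(...)`` in :meth:`input`, with the refractory filter carried on the self-edges of the connectivity tensor.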
@@ -40,13 +40,15 @@ class BernoulliGLM(BaseModel):
  abs_ref_scale : int
      The absolute refractory period of the neurons :math:`A_{ref}` in time steps
  abs_ref_strength : float
-     The large negative activation :math:`a` added to the neurons during the absolute refractory period
+     The large negative activation :math:`abs` added to the neurons during the absolute refractory period
  rel_ref_scale : int
      The relative refractory period of the neurons :math:`R_{ref}` in time steps
  rel_ref_strength : float
-     The negative activation :math:`r` added to the neurons during the relative refractory period (tunable)
+     The negative activation :math:`rel` added to the neurons during the relative refractory period (tunable)
  beta : float
      The decay rate :math:`\beta` of the weights. (tunable)
+ r : float
+     The scaling of the recurrent connections. (tunable)
  rng : torch.Generator
      The random number generator for sampling from the Bernoulli distribution.
  """
@@ -61,6 +63,7 @@ def __init__(self,
      rel_ref_scale: int,
      rel_ref_strength: int,
      beta: float,
+     r: float,
      rng=None
  ):
      super().__init__()
@@ -76,6 +79,7 @@ def __init__(self,

  # Parameters are used to store tensors that will be tunable
  self.register_parameter("theta", nn.Parameter(torch.tensor(theta, dtype=torch.float)))
+ self.register_parameter("r", torch.nn.Parameter(torch.tensor(r, dtype=torch.float)))
  self.register_parameter("beta", nn.Parameter(torch.tensor(beta, dtype=torch.float)))
  self.register_parameter("alpha", nn.Parameter(torch.tensor(alpha, dtype=torch.float)))
  self.register_parameter("rel_ref_strength", nn.Parameter(torch.tensor(rel_ref_strength, dtype=torch.float)))
@@ -88,7 +92,7 @@ def input(self, edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor,
Computes the input at time step :obj:`t+1` by adding together the synaptic input from neighboring neurons and the stimulus input.
.. math::
-     g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)r(\tau) + \sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)
+     g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)ref(\tau) + r\sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)
Parameters
----------
@@ -106,7 +110,7 @@ def input(self, edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor,
synaptic_input : torch.Tensor [n_neurons, 1]
"""
- return self.synaptic_input(edge_index, W, state=state) + self.stimulus_input(t)
+ return self.r * self.synaptic_input(edge_index, W, state=state) + self.stimulus_input(t)

def non_linearity(self, input: torch.Tensor) -> torch.Tensor:
r"""
@@ -150,7 +154,7 @@ def connectivity_filter(self, W0: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
r"""
The connectivity filter constructs a tensor holding the weights of the edges in the network.
This is done by filtering the initial coupling weights :math:`W_0` with the coupling filter :math:`c`
- and using a refractory filter :math:`r` as self-edge weights to emulate the refractory period.
+ and using a refractory filter :math:`ref` as self-edge weights to emulate the refractory period.
For the coupling edges, we are given an initial weight :math:`(W_0)_{i,j}` for each edge. This
tells us how strong the connection between neurons :math:`i` and :math:`j` is immediately after a spike event.
@@ -172,16 +176,16 @@
  This is modeled by weighting spike events by :math:`rel e^{-\alpha t \Delta t}` for
  the next :math:`R_{ref}` time steps.

- That is, the refractory filter :math:`r` is given by
+ That is, the refractory filter :math:`ref` is given by

  .. math::
-     r(t) = \begin{cases}
-         a & \text{if } t < A_{ref} \\
-         r e^{-\alpha t \Delta t} & \text{if } A_{ref} \leq t < A_{ref} + R_{ref} \\
+     ref(t) = \begin{cases}
+         abs & \text{if } t < A_{ref} \\
+         rel e^{-\alpha t \Delta t} & \text{if } A_{ref} \leq t < A_{ref} + R_{ref} \\
          0 & \text{if } A_{ref} + R_{ref} \leq t
      \end{cases}

- And we set :math:`W_{i, i}(t) = r(t)` for all neurons :math:`i`.
+ And we set :math:`W_{i, i}(t) = ref(t)` for all neurons :math:`i`.
All of this information can be represented by a tensor :math:`W` of shape :math:`N\times N\times T`, where
:code:`W[i, j, t]` is the weight of the edge from neuron :math:`i` to neuron :math:`j` at time step :math:`t` after a spike event.
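A hedged sketch of how this piecewise filter could be materialised as a tensor (the free function and its layout are illustrative assumptions; the parameter names mirror the constructor):

.. code-block:: python

    import torch

    def refractory_filter(abs_ref_scale, rel_ref_scale, abs_ref_strength,
                          rel_ref_strength, alpha, dt):
        # Piecewise filter: `abs_ref_strength` during the absolute period,
        # then an exponentially decaying `rel_ref_strength` during the relative one.
        T = abs_ref_scale + rel_ref_scale
        t = torch.arange(T, dtype=torch.float)
        return torch.where(
            t < abs_ref_scale,
            torch.full_like(t, abs_ref_strength),
            rel_ref_strength * torch.exp(-alpha * t * dt),
        )

    # Defaults from the documentation: abs=-100, rel=-30, alpha=0.2, dt=1 ms
    ref = refractory_filter(3, 7, -100.0, -30.0, 0.2, 1.0)

These values would form the self-edge weights :code:`W[i, i, :]` described above.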
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -96,6 +96,7 @@ def bernoulli_glm():
      rel_ref_scale=7,
      rel_ref_strength=-30.,
      alpha=0.2,
+     r=1,
      rng=rng,
  )
  return model
Binary file modified tests/test_data/stim_plan.pt
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_models.py
@@ -87,7 +87,7 @@ def test_save_load(bernoulli_glm):
from spikeometric.models import BernoulliGLM
with NamedTemporaryFile() as f:
bernoulli_glm.save(f.name)
- loaded_model = BernoulliGLM(1, 1, 1, 1, 1, 1, 1, 1, 1)
+ loaded_model = BernoulliGLM(1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
loaded_model.load(f.name)
for param, loaded_param in zip(bernoulli_glm.parameters(), loaded_model.parameters()):
assert_close(param, loaded_param)
