diff --git a/README.md b/README.md
index f160b3ce..b6ce35f4 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,8 @@ A Python package used for simulating spiking neural networks (SNNs) on CPUs or G
BindsNET is a spiking neural network simulation library geared towards the development of biologically inspired algorithms for machine learning.
-This package is used as part of ongoing research on applying SNNs to machine learning (ML) and reinforcement learning (RL) problems in the [Biologically Inspired Neural & Dynamical Systems (BINDS) lab](http://binds.cs.umass.edu/).
+This package is used as part of ongoing research on applying SNNs to machine learning (ML) and reinforcement learning (RL) problems in the [Biologically Inspired Neural & Dynamical Systems (BINDS) lab](http://binds.cs.umass.edu/) and the Allen Discovery Center at Tufts University.
+
Check out the [BindsNET examples](https://github.com/BindsNET/bindsnet/tree/master/examples) for a collection of experiments, functions for the analysis of results, plots of experiment outcomes, and more. Documentation for the package can be found [here](https://bindsnet-docs.readthedocs.io).
@@ -129,11 +130,22 @@ If you use BindsNET in your research, please cite the following [article](https:
## Contributors
-- Daniel Saunders ([email](mailto:djsaunde@cs.umass.edu))
-- Hananel Hazan ([email](mailto:hananel@hazan.org.il))
-- Darpan Sanghavi ([email](mailto:dsanghavi@cs.umass.edu))
-- Hassaan Khan ([email](mailto:hqkhan@umass.edu))
-- Devdhar Patel ([email](mailto:devdharpatel@cs.umass.edu))
+- Hava Siegelmann - Director of the BINDS lab at UMass
+- Robert Kozma - Co-director of the BINDS lab (2018-2019)
+- Hananel Hazan ([email](mailto:hananel@hazan.org.il)) - Spearheaded BindsNET and serves as its main maintainer
+- Daniel Saunders ([email](mailto:djsaunde@cs.umass.edu)) - MSc student, developer of BindsNET core functions (2018-2019)
+- Darpan Sanghavi ([email](mailto:dsanghavi@cs.umass.edu)) - MSc student, BINDS lab (2018)
+- Hassaan Khan ([email](mailto:hqkhan@umass.edu)) - MSc student, BINDS lab (2018)
+- Devdhar Patel ([email](mailto:devdharpatel@cs.umass.edu)) - MSc student, BINDS lab (2018)
+- Simon Caby ([GitHub](https://github.com/SimonInParis))
+- Christopher Earl ([email](mailto:cearl@umass.edu)) - MSc student (2021 - present)
+
+Made with [contrib.rocks](https://contrib.rocks).
## License
GNU Affero General Public License v3.0
diff --git a/bindsnet/analysis/plotting.py b/bindsnet/analysis/plotting.py
index bb5a39c1..2c2ebc04 100644
--- a/bindsnet/analysis/plotting.py
+++ b/bindsnet/analysis/plotting.py
@@ -703,7 +703,7 @@ def plot_voltages(
.numpy()[
time[0] : time[1],
n_neurons[v[0]][0] : n_neurons[v[0]][1],
- ]
+ ],
)
)
diff --git a/bindsnet/learning/MCC_learning.py b/bindsnet/learning/MCC_learning.py
new file mode 100644
index 00000000..14565a80
--- /dev/null
+++ b/bindsnet/learning/MCC_learning.py
@@ -0,0 +1,721 @@
+from abc import ABC, abstractmethod
+from typing import Union, Optional, Sequence
+import warnings
+
+import torch
+import numpy as np
+
+from ..network.nodes import SRM0Nodes
+from ..network.topology import (
+ AbstractMulticompartmentConnection,
+ MulticompartmentConnection,
+)
+from ..utils import im2col_indices
+
+
+class MCC_LearningRule(ABC):
+ # language=rst
+ """
+ Abstract base class for learning rules.
+ """
+
+ def __init__(
+ self,
+ connection: AbstractMulticompartmentConnection,
+ feature_value: Union[float, int, torch.Tensor],
+ range: Optional[Union[list, tuple]] = None,
+ nu: Optional[Union[float, Sequence[float]]] = None,
+ reduction: Optional[callable] = None,
+ decay: float = 0.0,
+ enforce_polarity: bool = False,
+ **kwargs,
+ ) -> None:
+ # language=rst
+ """
+ Abstract constructor for the ``LearningRule`` object.
+
+ :param connection: An ``AbstractConnection`` object.
+        :param feature_value: Value(s) to be updated. Must be a ``torch.Tensor`` (scalars are not currently supported)
+ :param range: Allowed range for :code:`feature_value`
+ :param nu: Single or pair of learning rates for pre- and post-synaptic events.
+ :param reduction: Method for reducing parameter updates along the batch
+ dimension.
+ :param decay: Coefficient controlling rate of decay of the weights each iteration.
+ :param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
+ """
+ # Connection parameters.
+ self.connection = connection
+ self.source = connection.source
+ self.target = connection.target
+ self.feature_value = feature_value
+ self.enforce_polarity = enforce_polarity
+        self.min, self.max = (None, None) if range is None else range
+
+ # Learning rate(s).
+ if nu is None:
+ nu = [0.2, 0.1]
+ elif isinstance(nu, (float, int)):
+ nu = [nu, nu]
+
+ # Keep track of polarities
+ if enforce_polarity:
+ self.polarities = torch.sign(self.feature_value)
+
+ self.nu = torch.zeros(2, dtype=torch.float)
+ self.nu[0] = nu[0]
+ self.nu[1] = nu[1]
+
+ if (self.nu == torch.zeros(2)).all() and not isinstance(self, NoOp):
+ warnings.warn(
+ f"nu is set to [0., 0.] for {type(self).__name__} learning rule. "
+ + "It will disable the learning process."
+ )
+
+ # Parameter update reduction across minibatch dimension.
+ if reduction is None:
+ if self.source.batch_size == 1:
+ self.reduction = torch.squeeze
+ else:
+ self.reduction = torch.sum
+ else:
+ self.reduction = reduction
+
+ # Weight decay.
+ self.decay = 1.0 - decay if decay else 1.0
+
+ def update(self, **kwargs) -> None:
+ # language=rst
+ """
+ Abstract method for a learning rule update.
+ """
+
+ # Implement decay.
+ if self.decay:
+ self.feature_value *= self.decay
+
+ # Enforce polarities
+ if self.enforce_polarity:
+ polarity_swaps = self.polarities == torch.sign(self.feature_value)
+ self.feature_value[polarity_swaps == 0] = 0
+
+ # Bound weights.
+ if ((self.min is not None) or (self.max is not None)) and not isinstance(
+ self, NoOp
+ ):
+ self.feature_value.clamp_(self.min, self.max)
+
+ @abstractmethod
+ def reset_state_variables(self) -> None:
+ # language=rst
+ """
+ Contains resetting logic for the feature.
+ """
+ pass
+
+
+class NoOp(MCC_LearningRule):
+ # language=rst
+ """
+ Learning rule with no effect.
+ """
+
+ def __init__(self, **args) -> None:
+ # language=rst
+ """
+ No operation done during runtime
+ """
+ pass
+
+ def update(self, **kwargs) -> None:
+ # language=rst
+ """
+ No operation done during runtime
+ """
+ pass
+
+ def reset_state_variables(self) -> None:
+ # language=rst
+ """
+ Contains resetting logic for the feature.
+ """
+ pass
+
+
+class PostPre(MCC_LearningRule):
+ # language=rst
+ """
+ Simple STDP rule involving both pre- and post-synaptic spiking activity. By default,
+ pre-synaptic update is negative and the post-synaptic update is positive.
+ """
+
+ def __init__(
+ self,
+ connection: AbstractMulticompartmentConnection,
+ feature_value: Union[torch.Tensor, float, int],
+ range: Optional[Sequence[float]] = None,
+ nu: Optional[Union[float, Sequence[float]]] = None,
+ reduction: Optional[callable] = None,
+ decay: float = 0.0,
+ enforce_polarity: bool = False,
+ **kwargs,
+ ) -> None:
+ # language=rst
+ """
+ Constructor for ``PostPre`` learning rule.
+
+ :param connection: An ``AbstractConnection`` object whose weights the
+ ``PostPre`` learning rule will modify.
+ :param feature_value: The object which will be altered
+ :param range: The domain for the feature
+ :param nu: Single or pair of learning rates for pre- and post-synaptic events.
+ :param reduction: Method for reducing parameter updates along the batch
+ dimension.
+ :param decay: Coefficient controlling rate of decay of the weights each iteration.
+ :param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
+
+ Keyword arguments:
+        :param average_update: Number of updates to average over; 0 = no averaging, x = average over the last x updates
+        :param continues_update: If ``True``, the (averaged) update is applied after every time step; if ``False``, it is applied only once the ``average_update`` buffer is full
+ """
+ super().__init__(
+ connection=connection,
+ feature_value=feature_value,
+ range=[-1, +1] if range is None else range,
+ nu=nu,
+ reduction=reduction,
+ decay=decay,
+ enforce_polarity=enforce_polarity,
+ **kwargs,
+ )
+
+ assert self.source.traces and self.target.traces, (
+ "Both pre- and post-synaptic nodes must record spike traces "
+ "(use traces='True' on source/target layers)"
+ )
+
+ if isinstance(connection, (MulticompartmentConnection)):
+ self.update = self._connection_update
+ # elif isinstance(connection, Conv2dConnection):
+ # self.update = self._conv2d_connection_update
+ else:
+ raise NotImplementedError(
+ "This learning rule is not supported for this Connection type."
+ )
+
+ # Initialize variables for average update and continues update
+ self.average_update = kwargs.get("average_update", 0)
+ self.continues_update = kwargs.get("continues_update", False)
+
+ if self.average_update > 0:
+ self.average_buffer_pre = torch.zeros(
+ self.average_update,
+ *self.feature_value.shape,
+ device=self.feature_value.device,
+ )
+ self.average_buffer_post = torch.zeros_like(self.average_buffer_pre)
+ self.average_buffer_index_pre = 0
+ self.average_buffer_index_post = 0
+
+ def _connection_update(self, **kwargs) -> None:
+ # language=rst
+ """
+ Post-pre learning rule for ``Connection`` subclass of ``AbstractConnection``
+ class.
+ """
+ batch_size = self.source.batch_size
+
+ # Pre-synaptic update.
+ if self.nu[0]:
+ source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
+ target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
+
+ if self.average_update > 0:
+ self.average_buffer_pre[self.average_buffer_index_pre] = self.reduction(
+ torch.bmm(source_s, target_x), dim=0
+ )
+
+ self.average_buffer_index_pre = (
+ self.average_buffer_index_pre + 1
+ ) % self.average_update
+
+ if self.continues_update:
+ self.feature_value -= (
+ torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt
+ )
+ elif self.average_buffer_index_pre == 0:
+ self.feature_value -= (
+ torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt
+ )
+ else:
+ self.feature_value -= (
+ self.reduction(torch.bmm(source_s, target_x), dim=0)
+ * self.connection.dt
+ )
+ del source_s, target_x
+
+ # Post-synaptic update.
+ if self.nu[1]:
+ target_s = (
+ self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
+ )
+ source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
+
+ if self.average_update > 0:
+ self.average_buffer_post[self.average_buffer_index_post] = (
+ self.reduction(torch.bmm(source_x, target_s), dim=0)
+ )
+
+ self.average_buffer_index_post = (
+ self.average_buffer_index_post + 1
+ ) % self.average_update
+
+ if self.continues_update:
+ self.feature_value += (
+ torch.mean(self.average_buffer_post, dim=0) * self.connection.dt
+ )
+ elif self.average_buffer_index_post == 0:
+ self.feature_value += (
+ torch.mean(self.average_buffer_post, dim=0) * self.connection.dt
+ )
+ else:
+ self.feature_value += (
+ self.reduction(torch.bmm(source_x, target_s), dim=0)
+ * self.connection.dt
+ )
+ del source_x, target_s
+
+ super().update()
+
+ def reset_state_variables(self):
+ return
+
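+# Usage sketch (comment only; names and learning rates below are assumptions, not
+# part of this module): an MCC learning rule is normally not instantiated directly.
+# It is attached to a feature from ``bindsnet.network.topology_features`` (e.g.
+# ``Weight``), and the feature passes itself as ``feature_value`` when it is primed
+# by a ``MulticompartmentConnection``:
+#
+#     from bindsnet.network.topology_features import Weight
+#
+#     weight = Weight(name="weight", learning_rule=PostPre, nu=(1e-4, 1e-2))
+#     # After priming, ``weight.learning_rule`` is a ``PostPre`` instance that
+#     # updates ``weight.value`` on every connection update.
+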
+class Hebbian(MCC_LearningRule):
+ # language=rst
+ """
+ Simple Hebbian learning rule. Pre- and post-synaptic updates are both positive.
+ """
+
+ def __init__(
+ self,
+ connection: AbstractMulticompartmentConnection,
+ feature_value: Union[torch.Tensor, float, int],
+ nu: Optional[Union[float, Sequence[float]]] = None,
+ reduction: Optional[callable] = None,
+ decay: float = 0.0,
+ **kwargs,
+ ) -> None:
+ # language=rst
+ """
+ Constructor for ``Hebbian`` learning rule.
+
+ :param connection: An ``AbstractConnection`` object whose weights the
+ ``Hebbian`` learning rule will modify.
+ :param nu: Single or pair of learning rates for pre- and post-synaptic events.
+ :param reduction: Method for reducing parameter updates along the batch
+ dimension.
+ :param decay: Coefficient controlling rate of decay of the weights each iteration.
+ """
+ super().__init__(
+ connection=connection,
+ feature_value=feature_value,
+ nu=nu,
+ reduction=reduction,
+ decay=decay,
+ **kwargs,
+ )
+
+ assert (
+ self.source.traces and self.target.traces
+ ), "Both pre- and post-synaptic nodes must record spike traces."
+
+        if isinstance(connection, MulticompartmentConnection):
+ self.update = self._connection_update
+ self.feature_value = feature_value
+ # elif isinstance(connection, Conv2dConnection):
+ # self.update = self._conv2d_connection_update
+ else:
+ raise NotImplementedError(
+ "This learning rule is not supported for this Connection type."
+ )
+
+ def _connection_update(self, **kwargs) -> None:
+ # language=rst
+ """
+ Hebbian learning rule for ``Connection`` subclass of ``AbstractConnection``
+ class.
+ """
+
+ # Add polarities back to feature after updates
+ if self.enforce_polarity:
+ self.feature_value = torch.abs(self.feature_value)
+
+ batch_size = self.source.batch_size
+
+ source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
+ source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
+ target_s = self.target.s.view(batch_size, -1).unsqueeze(1).float()
+ target_x = self.target.x.view(batch_size, -1).unsqueeze(1)
+
+ # Pre-synaptic update.
+ update = self.reduction(torch.bmm(source_s, target_x), dim=0)
+ self.feature_value += self.nu[0] * update
+
+ # Post-synaptic update.
+ update = self.reduction(torch.bmm(source_x, target_s), dim=0)
+ self.feature_value += self.nu[1] * update
+
+ # Add polarities back to feature after updates
+ if self.enforce_polarity:
+ self.feature_value = self.feature_value * self.polarities
+
+ super().update()
+
+ def reset_state_variables(self):
+ return
+
+
+class MSTDP(MCC_LearningRule):
+ # language=rst
+ """
+    Reward-modulated STDP. Adapted from (Florian 2007).
+ """
+
+ def __init__(
+ self,
+ connection: AbstractMulticompartmentConnection,
+ feature_value: Union[torch.Tensor, float, int],
+ range: Optional[Sequence[float]] = None,
+ nu: Optional[Union[float, Sequence[float]]] = None,
+ reduction: Optional[callable] = None,
+ decay: float = 0.0,
+ enforce_polarity: bool = False,
+ **kwargs,
+ ) -> None:
+ # language=rst
+ """
+ Constructor for ``MSTDP`` learning rule.
+
+ :param connection: An ``AbstractConnection`` object whose weights the ``MSTDP``
+ learning rule will modify.
+ :param feature_value: The object which will be altered
+ :param range: The domain for the feature
+ :param nu: Single or pair of learning rates for pre- and post-synaptic events,
+ respectively.
+ :param reduction: Method for reducing parameter updates along the minibatch
+ dimension.
+ :param decay: Coefficient controlling rate of decay of the weights each iteration.
+ :param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
+
+ Keyword arguments:
+
+    :param average_update: Number of updates to average over; 0 = no averaging, x = average over the last x updates
+    :param continues_update: If ``True``, the (averaged) update is applied after every time step; if ``False``, it is applied only once the ``average_update`` buffer is full
+
+ :param tc_plus: Time constant for pre-synaptic firing trace.
+ :param tc_minus: Time constant for post-synaptic firing trace.
+ """
+ super().__init__(
+ connection=connection,
+ feature_value=feature_value,
+ range=[-1, +1] if range is None else range,
+ nu=nu,
+ reduction=reduction,
+ decay=decay,
+ enforce_polarity=enforce_polarity,
+ **kwargs,
+ )
+
+ if isinstance(connection, (MulticompartmentConnection)):
+ self.update = self._connection_update
+ # elif isinstance(connection, Conv2dConnection):
+ # self.update = self._conv2d_connection_update
+ else:
+ raise NotImplementedError(
+ "This learning rule is not supported for this Connection type."
+ )
+
+ self.tc_plus = torch.tensor(kwargs.get("tc_plus", 20.0))
+ self.tc_minus = torch.tensor(kwargs.get("tc_minus", 20.0))
+
+ # Initialize variables for average update and continues update
+ self.average_update = kwargs.get("average_update", 0)
+ self.continues_update = kwargs.get("continues_update", False)
+
+ if self.average_update > 0:
+ self.average_buffer = torch.zeros(
+ self.average_update,
+ *self.feature_value.shape,
+ device=self.feature_value.device,
+ )
+ self.average_buffer_index = 0
+
+ def _connection_update(self, **kwargs) -> None:
+ # language=rst
+ """
+ MSTDP learning rule for ``Connection`` subclass of ``AbstractConnection`` class.
+
+ Keyword arguments:
+
+ :param Union[float, torch.Tensor] reward: Reward signal from reinforcement
+ learning task.
+ :param float a_plus: Learning rate (post-synaptic).
+ :param float a_minus: Learning rate (pre-synaptic).
+ """
+ batch_size = self.source.batch_size
+
+ # Initialize eligibility, P^+, and P^-.
+ if not hasattr(self, "p_plus"):
+ self.p_plus = torch.zeros(
+ # batch_size, *self.source.shape, device=self.source.s.device
+ batch_size,
+ self.source.n,
+ device=self.source.s.device,
+ )
+ if not hasattr(self, "p_minus"):
+ self.p_minus = torch.zeros(
+ # batch_size, *self.target.shape, device=self.target.s.device
+ batch_size,
+ self.target.n,
+ device=self.target.s.device,
+ )
+ if not hasattr(self, "eligibility"):
+ self.eligibility = torch.zeros(
+ batch_size, *self.feature_value.shape, device=self.feature_value.device
+ )
+
+ # Reshape pre- and post-synaptic spikes.
+ source_s = self.source.s.view(batch_size, -1).float()
+ target_s = self.target.s.view(batch_size, -1).float()
+
+ # Parse keyword arguments.
+ reward = kwargs["reward"]
+ a_plus = torch.tensor(
+ kwargs.get("a_plus", 1.0), device=self.feature_value.device
+ )
+ a_minus = torch.tensor(
+ kwargs.get("a_minus", -1.0), device=self.feature_value.device
+ )
+
+ # Compute weight update based on the eligibility value of the past timestep.
+ update = reward * self.eligibility
+
+ if self.average_update > 0:
+ self.average_buffer[self.average_buffer_index] = self.reduction(
+ update, dim=0
+ )
+ self.average_buffer_index = (
+ self.average_buffer_index + 1
+ ) % self.average_update
+
+ if self.continues_update:
+ self.feature_value += self.nu[0] * torch.mean(
+ self.average_buffer, dim=0
+ )
+ elif self.average_buffer_index == 0:
+ self.feature_value += self.nu[0] * torch.mean(
+ self.average_buffer, dim=0
+ )
+ else:
+ self.feature_value += self.nu[0] * self.reduction(update, dim=0)
+
+ # Update P^+ and P^- values.
+ self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
+ self.p_plus += a_plus * source_s
+ self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
+ self.p_minus += a_minus * target_s
+
+ # Calculate point eligibility value.
+ self.eligibility = torch.bmm(
+ self.p_plus.unsqueeze(2), target_s.unsqueeze(1)
+ ) + torch.bmm(source_s.unsqueeze(2), self.p_minus.unsqueeze(1))
+
+ super().update()
+
+ def reset_state_variables(self):
+ return
+
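+# Usage sketch (comment only): ``MSTDP`` expects a ``reward`` keyword at update
+# time. Keyword arguments passed to ``Network.run`` are forwarded to connection
+# updates, so a reward-modulated step typically looks like the line below (the
+# layer name and values are assumptions for the sketch):
+#
+#     network.run(inputs={"X": spikes}, time=250, reward=1.0)
+#
+# Optional ``a_plus`` / ``a_minus`` keywords rescale the P^+ / P^- traces.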
+
+class MSTDPET(MCC_LearningRule):
+ # language=rst
+ """
+    Reward-modulated STDP with eligibility trace. Adapted from (Florian 2007).
+ """
+
+ def __init__(
+ self,
+ connection: AbstractMulticompartmentConnection,
+ feature_value: Union[torch.Tensor, float, int],
+ range: Optional[Sequence[float]] = None,
+ nu: Optional[Union[float, Sequence[float]]] = None,
+ reduction: Optional[callable] = None,
+ decay: float = 0.0,
+ enforce_polarity: bool = False,
+ **kwargs,
+ ) -> None:
+ # language=rst
+ """
+ Constructor for ``MSTDPET`` learning rule.
+
+ :param connection: An ``AbstractConnection`` object whose weights the
+ ``MSTDPET`` learning rule will modify.
+ :param feature_value: The object which will be altered
+ :param range: The domain for the feature
+ :param nu: Single or pair of learning rates for pre- and post-synaptic events,
+ respectively.
+ :param reduction: Method for reducing parameter updates along the minibatch
+ dimension.
+ :param decay: Coefficient controlling rate of decay of the weights each iteration.
+ :param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
+
+ Keyword arguments:
+
+ :param float tc_plus: Time constant for pre-synaptic firing trace.
+ :param float tc_minus: Time constant for post-synaptic firing trace.
+ :param float tc_e_trace: Time constant for the eligibility trace.
+    :param average_update: Number of updates to average over; 0 = no averaging, x = average over the last x updates
+    :param continues_update: If ``True``, the (averaged) update is applied after every time step; if ``False``, it is applied only once the ``average_update`` buffer is full
+
+ """
+ super().__init__(
+ connection=connection,
+ feature_value=feature_value,
+ range=[-1, +1] if range is None else range,
+ nu=nu,
+ reduction=reduction,
+ decay=decay,
+ enforce_polarity=enforce_polarity,
+ **kwargs,
+ )
+
+ if isinstance(connection, (MulticompartmentConnection)):
+ self.update = self._connection_update
+ # elif isinstance(connection, Conv2dConnection):
+ # self.update = self._conv2d_connection_update
+ else:
+ raise NotImplementedError(
+ "This learning rule is not supported for this Connection type."
+ )
+
+        self.tc_plus = torch.tensor(
+            kwargs.get("tc_plus", 20.0)
+        )  # How long positive reinforcement affects weights
+        self.tc_minus = torch.tensor(
+            kwargs.get("tc_minus", 20.0)
+        )  # How long negative reinforcement affects weights
+        self.tc_e_trace = torch.tensor(
+            kwargs.get("tc_e_trace", 25.0)
+        )  # How long the eligibility trace affects weights
+ self.eligibility = torch.zeros(
+ *self.feature_value.shape, device=self.feature_value.device
+ )
+ self.eligibility_trace = torch.zeros(
+ *self.feature_value.shape, device=self.feature_value.device
+ )
+
+ # Initialize eligibility, eligibility trace, P^+, and P^-.
+ if not hasattr(self, "p_plus"):
+ self.p_plus = torch.zeros((self.source.n), device=self.feature_value.device)
+ if not hasattr(self, "p_minus"):
+ self.p_minus = torch.zeros(
+ (self.target.n), device=self.feature_value.device
+ )
+
+ # Initialize variables for average update and continues update
+ self.average_update = kwargs.get("average_update", 0)
+ self.continues_update = kwargs.get("continues_update", False)
+ if self.average_update > 0:
+ self.average_buffer = torch.zeros(
+ self.average_update,
+ *self.feature_value.shape,
+ device=self.feature_value.device,
+ )
+ self.average_buffer_index = 0
+
+ # @profile
+ def _connection_update(self, **kwargs) -> None:
+ # language=rst
+ """
+ MSTDPET learning rule for ``Connection`` subclass of ``AbstractConnection``
+ class.
+
+ Keyword arguments:
+
+ :param Union[float, torch.Tensor] reward: Reward signal from reinforcement
+ learning task.
+ :param float a_plus: Learning rate (post-synaptic).
+ :param float a_minus: Learning rate (pre-synaptic).
+ """
+ # Reshape pre- and post-synaptic spikes.
+ source_s = self.source.s.view(-1).float()
+ target_s = self.target.s.view(-1).float()
+
+ # Parse keyword arguments.
+ reward = kwargs["reward"]
+ a_plus = kwargs.get("a_plus", 1.0)
+ # if isinstance(a_plus, dict):
+ # for k, v in a_plus.items():
+ # a_plus[k] = torch.tensor(v, device=self.feature_value.device)
+ # else:
+ a_plus = torch.tensor(a_plus, device=self.feature_value.device)
+ a_minus = kwargs.get("a_minus", -1.0)
+ # if isinstance(a_minus, dict):
+ # for k, v in a_minus.items():
+ # a_minus[k] = torch.tensor(v, device=self.feature_value.device)
+ # else:
+ a_minus = torch.tensor(a_minus, device=self.feature_value.device)
+
+ # Calculate value of eligibility trace based on the value
+ # of the point eligibility value of the past timestep.
+        # Note: eligibility = [source.n, target.n] > 0 where source and target spiked
+ self.eligibility_trace *= torch.exp(
+ -self.connection.dt / self.tc_e_trace
+ ) # Decay
+ self.eligibility_trace += self.eligibility / self.tc_e_trace # Additive changes
+        # ^ Also affected by the delay from the last step
+
+ # Compute weight update.
+
+ if self.average_update > 0:
+ self.average_buffer[self.average_buffer_index] = (
+ self.nu[0] * self.connection.dt * reward * self.eligibility_trace
+ )
+ self.average_buffer_index = (
+ self.average_buffer_index + 1
+ ) % self.average_update
+
+ if self.continues_update:
+ self.feature_value += torch.mean(self.average_buffer, dim=0)
+ elif self.average_buffer_index == 0:
+ self.feature_value += torch.mean(self.average_buffer, dim=0)
+ else:
+ self.feature_value += (
+ self.nu[0] * self.connection.dt * reward * self.eligibility_trace
+ )
+
+ # Update P^+ and P^- values.
+ self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) # Decay
+ self.p_plus += a_plus * source_s # Scaled source spikes
+ self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) # Decay
+ self.p_minus += a_minus * target_s # Scaled target spikes
+
+ # Notes:
+ #
+ # a_plus -> How much a spike in src contributes to the eligibility
+ # a_minus -> How much a spike in trg contributes to the eligibility (neg)
+ # p_plus -> +a_plus every spike, with decay
+ # p_minus -> +a_minus every spike, with decay
+
+ # Calculate point eligibility value.
+ self.eligibility = torch.outer(self.p_plus, target_s) + torch.outer(
+ source_s, self.p_minus
+ )
+
+ super().update()
+
+ def reset_state_variables(self) -> None:
+ self.eligibility.zero_()
+ self.eligibility_trace.zero_()
+ return
diff --git a/bindsnet/network/monitors.py b/bindsnet/network/monitors.py
index 7dc7cc3e..f11a2339 100644
--- a/bindsnet/network/monitors.py
+++ b/bindsnet/network/monitors.py
@@ -4,9 +4,17 @@
import numpy as np
import torch
+
+from abc import ABC
+from typing import Union, Optional, Iterable, Dict
from bindsnet.network.nodes import Nodes
-from bindsnet.network.topology import AbstractConnection
+from bindsnet.network.topology import (
+ AbstractConnection,
+ AbstractMulticompartmentConnection,
+)
+from bindsnet.network.topology_features import AbstractFeature
if TYPE_CHECKING:
from .network import Network
@@ -27,7 +35,12 @@ class Monitor(AbstractMonitor):
def __init__(
self,
- obj: Union[Nodes, AbstractConnection],
+ obj: Union[
+ Nodes,
+ AbstractConnection,
+ AbstractMulticompartmentConnection,
+ AbstractFeature,
+ ],
state_vars: Iterable[str],
time: Optional[int] = None,
batch_size: int = 1,
@@ -54,8 +67,6 @@ def __init__(
if self.time is None:
self.device = "cpu"
- self.clean = True
-
self.recording = []
self.reset_state_variables()
@@ -179,7 +190,20 @@ def __init__(
self.time, *getattr(self.network.connections[c], v).size()
)
- def get(self) -> Dict[str, Dict[str, Union[Nodes, AbstractConnection]]]:
+ def get(
+ self,
+ ) -> Dict[
+ str,
+ Dict[
+ str,
+ Union[
+ Nodes,
+ AbstractConnection,
+ AbstractMulticompartmentConnection,
+ AbstractFeature,
+ ],
+ ],
+ ]:
# language=rst
"""
Return entire recording to user.
diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py
index 98ac2555..cb5fafa1 100644
--- a/bindsnet/network/topology.py
+++ b/bindsnet/network/topology.py
@@ -1,8 +1,11 @@
from abc import ABC, abstractmethod
from typing import Optional, Sequence, Tuple, Union
+import warnings
+
import numpy as np
import torch
+from torch import device
import torch.nn.functional as F
from bindsnet.utils import im2col_indices
from torch.nn import Module, Parameter
@@ -139,6 +142,112 @@ def reset_state_variables(self) -> None:
"""
+class AbstractMulticompartmentConnection(ABC, Module):
+ # language=rst
+ """
+ Abstract base method for connections between ``Nodes``.
+ """
+
+ def __init__(
+ self,
+ source: Nodes,
+ target: Nodes,
+ device: device,
+ pipeline: list = None,
+ **kwargs,
+ ) -> None:
+ # language=rst
+ """
+ Constructor for abstract base class for connection objects.
+
+ :param source: A layer of nodes from which the connection originates.
+ :param target: A layer of nodes to which the connection connects.
+ :param device: The device which the connection will run on
+ :param pipeline: An ordered list of topology features to be used on the connection
+ """
+
+ super().__init__()
+
+ #### General Assertions ####
+ assert isinstance(source, Nodes), "Source is not a Nodes object"
+ assert isinstance(target, Nodes), "Target is not a Nodes object"
+
+ #### Assign class variables ####
+ self.source = source
+ self.target = target
+ self.device = device
+ self.pipeline = (
+ [] if pipeline is None else pipeline
+ ) # <- *Ordered* executables for features
+
+ # TODO: Make it so there can't be repeated names!!!
+ # Initialize feature index & prime
+ self.feature_index = (
+ {}
+ ) # <- *Unordered* and named set of references for features
+        for feature in self.pipeline:
+ self.feature_index[feature.name] = feature
+ feature.prime_feature(connection=self, device=self.device, **kwargs)
+
+ @abstractmethod
+ def compute(self, s: torch.Tensor) -> None:
+ # language=rst
+ """
+ Compute pre-activations of downstream neurons given spikes of upstream neurons.
+
+ :param s: Incoming spikes.
+ """
+ pass
+
+ @abstractmethod
+ def update(self, **kwargs) -> None:
+ # language=rst
+ """
+ Compute connection's update rule.
+
+ Keyword arguments:
+
+ :param bool learning: Whether to allow connection updates.
+ """
+ pass
+
+ @abstractmethod
+ def reset_state_variables(self) -> None:
+ # language=rst
+ """
+ Contains resetting logic for the connection.
+ """
+ pass
+
+ def append_pipeline(self, feature) -> None:
+ # language=rst
+ """
+ Append a feature to the pipeline
+ """
+ self.pipeline.append(feature)
+ feature.prime_feature(connection=self, device=self.device)
+ self.feature_index[feature.name] = feature
+
+ def insert_pipeline(self, feature, index) -> None:
+ # language=rst
+ """
+        Insert a feature into the pipeline at the given position.
+        :param feature: Feature to be inserted
+        :param index: Index at which to insert the feature
+        """
+        self.pipeline.insert(index, feature)
+ feature.prime_feature(connection=self, device=self.device)
+ self.feature_index[feature.name] = feature
+
+ def remove_pipeline(self, feature) -> None:
+ # language=rst
+ """
+        Remove a feature from the pipeline.
+        :param feature: Feature to be removed
+ """
+ self.pipeline.remove(feature)
+ del self.feature_index[feature.name]
+
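+# Usage sketch (comment only; the feature name is an assumption): the pipeline can
+# also be managed after construction through the helpers above, e.g.
+#
+#     conn.append_pipeline(Intensity(name="inh", value=-1.0))
+#     conn.remove_pipeline(conn.feature_index["inh"])
+#
+# where ``Intensity`` comes from ``bindsnet.network.topology_features``.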
+
class Connection(AbstractConnection):
# language=rst
"""
@@ -272,6 +381,133 @@ def reset_state_variables(self) -> None:
super().reset_state_variables()
+class MulticompartmentConnection(AbstractMulticompartmentConnection):
+ # language=rst
+ """
+ Specifies synapses between one or two populations of neurons.
+ """
+
+ def __init__(
+ self,
+ source: Nodes,
+ target: Nodes,
+ device: device,
+        pipeline: list = None,
+ manual_update: bool = False,
+ traces: bool = False,
+ **kwargs,
+ ) -> None:
+ # language=rst
+ """
+ Instantiates a :code:`Connection` object.
+
+ :param source: A layer of nodes from which the connection originates.
+ :param target: A layer of nodes to which the connection connects.
+ :param device: The device the connection will be run on.
+        :param pipeline: Pipeline of features for the connection signals to be run through
+ :param manual_update: Set to :code:`True` to disable automatic updates (applying learning rules) to connection features.
+ False by default, updates called after each time step
+ :param traces: Set to :code:`True` to record history of connection activity (for monitors)
+ """
+
+ super().__init__(source, target, device, pipeline, **kwargs)
+ self.traces = traces
+ self.manual_update = manual_update
+        if self.traces:
+            self.activity = None
+        self.s_w = None  # Rolling spike window used by compute_window(); lazily allocated
+
+ def compute(self, s: torch.Tensor) -> torch.Tensor:
+ # language=rst
+ """
+ Compute pre-activations given spikes using connection weights.
+
+ :param s: Incoming spikes.
+ :return: Incoming spikes multiplied by synaptic weights (with or without
+ decaying spike activation).
+ """
+
+ # Change to numeric type (torch doesn't like booleans for matrix ops)
+ # Note: .float() is an expensive operation. Use as minimally as possible!
+ # if s.dtype != torch.float32:
+ # s = s.float()
+
+ # Prepare broadcast from incoming spikes to all output neurons
+ # |conn_spikes| = [batch_size, source.n * target.n]
+ conn_spikes = s.view(s.size(0), self.source.n, 1).repeat(1, 1, self.target.n)
+ # TODO: ^ This could probably be optimized
+
+ # Run through pipeline
+ for f in self.pipeline:
+ conn_spikes = f.compute(conn_spikes)
+
+ # Sum signals for each of the output/terminal neurons
+ # |out_signal| = [batch_size, target.n]
+ out_signal = conn_spikes.view(s.size(0), self.source.n, self.target.n).sum(1)
+
+ if self.traces:
+ self.activity = out_signal
+
+ return out_signal.view(s.size(0), *self.target.shape)
+
+    def compute_window(self, s: torch.Tensor) -> torch.Tensor:
+        # language=rst
+        """
+        Compute pre-activations over a rolling window of recent spikes.
+
+        :param s: Incoming spikes.
+        """
+
+        if self.s_w is None:
+ # Construct a matrix of shape batch size * window size * dimension of layer
+ self.s_w = torch.zeros(
+ self.target.batch_size, self.target.res_window_size, *self.source.shape
+ )
+
+ # Add the spike vector into the first in first out matrix of windowed (res) spike trains
+ self.s_w = torch.cat((self.s_w[:, 1:, :], s[:, None, :]), 1)
+
+ # Compute multiplication of spike activations by weights and add bias.
+ if self.b is None:
+ post = (
+ self.s_w.view(self.s_w.size(0), self.s_w.size(1), -1).float() @ self.w
+ )
+ else:
+ post = (
+ self.s_w.view(self.s_w.size(0), self.s_w.size(1), -1).float() @ self.w
+ + self.b
+ )
+
+ return post.view(
+ self.s_w.size(0), self.target.res_window_size, *self.target.shape
+ )
+
+ def update(self, **kwargs) -> None:
+ # language=rst
+ """
+ Compute connection's update rule.
+ """
+ learning = kwargs.get("learning", False)
+ if learning and not self.manual_update:
+ # Pipeline learning
+ for f in self.pipeline:
+ f.update(**kwargs)
+
+ def normalize(self) -> None:
+ # language=rst
+ """
+ Normalize all features in the connection.
+ """
+ # Normalize pipeline features
+ for f in self.pipeline:
+ f.normalize()
+
+ def reset_state_variables(self) -> None:
+ # language=rst
+ """
+ Contains resetting logic for the connection.
+ """
+ super().reset_state_variables()
+
+ for f in self.pipeline:
+ f.reset_state_variables()
+
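+# Construction sketch (comment only; layer names and values are assumptions): a
+# ``MulticompartmentConnection`` routes spikes through an ordered pipeline of
+# features, and keyword arguments given to the connection (e.g. ``average_update``)
+# are forwarded to each feature's learning rule when the feature is primed:
+#
+#     from bindsnet.network.topology_features import Weight, Probability
+#
+#     conn = MulticompartmentConnection(
+#         source=input_layer,
+#         target=output_layer,
+#         device="cpu",
+#         pipeline=[
+#             Weight(name="weight", norm=78.4),
+#             Probability(name="prob", value=0.9),
+#         ],
+#     )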
+
class Conv1dConnection(AbstractConnection):
# language=rst
"""
diff --git a/bindsnet/network/topology_features.py b/bindsnet/network/topology_features.py
new file mode 100644
index 00000000..f99cf39f
--- /dev/null
+++ b/bindsnet/network/topology_features.py
@@ -0,0 +1,940 @@
+from abc import ABC, abstractmethod
+from bindsnet.learning.learning import NoOp
+from typing import Union, Tuple, Optional, Sequence
+
+import numpy as np
+import torch
+from torch import device
+from torch.nn import Parameter
+import torch.nn.functional as F
+import torch.nn as nn
+import bindsnet.learning
+
+
+class AbstractFeature(ABC):
+ # language=rst
+ """
+ Features to operate on signals traversing a connection.
+ """
+
+ @abstractmethod
+ def __init__(
+ self,
+ name: str,
+ value: Union[torch.Tensor, float, int] = None,
+ range: Optional[Union[list, tuple]] = None,
+ clamp_frequency: Optional[int] = 1,
+ norm: Optional[Union[torch.Tensor, float, int]] = None,
+ learning_rule: Optional[bindsnet.learning.LearningRule] = None,
+ nu: Optional[Union[list, tuple, int, float]] = None,
+ reduction: Optional[callable] = None,
+ enforce_polarity: Optional[bool] = False,
+ decay: float = 0.0,
+ parent_feature=None,
+ **kwargs,
+ ) -> None:
+ # language=rst
+ """
+ Instantiates a :code:`Feature` object. Will assign all incoming arguments as class variables
+ :param name: Name of the feature
+ :param value: Core numeric object for the feature. This parameters function will vary depending on the feature
+ :param range: Range of acceptable values for the :code:`value` parameter
+ :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each
+ sample and after the value has been updated by the learning rule (if there is one)
+ :param learning_rule: Rule which will modify the :code:`value` after each sample
+ :param nu: Learning rate for the learning rule
+ :param reduction: Method for reducing parameter updates along the minibatch
+ dimension
+ :param decay: Constant multiple to decay weights by on each iteration
+ :param parent_feature: Parent feature to inherit :code:`value` from
+ """
+
+ #### Initialize class variables ####
+ ## Args ##
+ self.name = name
+ self.value = value
+ self.range = [-1.0, 1.0] if range is None else range
+ self.clamp_frequency = clamp_frequency
+ self.norm = norm
+ self.learning_rule = learning_rule
+ self.nu = nu
+ self.reduction = reduction
+ self.decay = decay
+ self.parent_feature = parent_feature
+ self.kwargs = kwargs
+
+ ## Backend ##
+ self.is_primed = False
+
+ from ..learning.MCC_learning import (
+ NoOp,
+ PostPre,
+ MSTDP,
+ MSTDPET,
+ )
+
+ supported_rules = [
+ NoOp,
+ PostPre,
+ MSTDP,
+ MSTDPET,
+ ]
+
+ #### Assertions ####
+ # Assert correct instance of feature values
+ assert isinstance(name, str), "Feature {0}'s name should be of type str".format(
+ name
+ )
+ assert value is None or isinstance(
+ value, (torch.Tensor, float, int)
+ ), "Feature {0} should be of type float, int, or torch.Tensor, not {1}".format(
+ name, type(value)
+ )
+ assert norm is None or isinstance(
+ norm, (torch.Tensor, float, int)
+ ), "Feature {0}'s norm should be of type float, int, or torch.Tensor, not {1}".format(
+ name, type(norm)
+ )
+ assert learning_rule is None or (
+ learning_rule in supported_rules
+ ), "Feature {0}'s learning_rule should be of type bindsnet.LearningRule not {1}".format(
+ name, type(learning_rule)
+ )
+ assert nu is None or isinstance(
+ nu, (list, tuple)
+ ), "Feature {0}'s nu should be of type list or tuple, not {1}".format(
+ name, type(nu)
+ )
+        assert reduction is None or callable(
+            reduction
+        ), "Feature {0}'s reduction should be callable, not {1}".format(
+            name, type(reduction)
+        )
+ assert decay is None or isinstance(
+ decay, float
+ ), "Feature {0}'s decay should be of type float, not {1}".format(
+ name, type(decay)
+ )
+
+ self.assert_valid_range()
+ if value is not None:
+ self.assert_feature_in_range()
+
+ @abstractmethod
+ def reset_state_variables(self) -> None:
+ # language=rst
+ """
+ Contains resetting logic for the feature.
+ """
+ if self.learning_rule:
+ self.learning_rule.reset_state_variables()
+ pass
+
+ @abstractmethod
+ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+ # language=rst
+ """
+ Computes the feature being operated on a set of incoming signals.
+ """
+ pass
+
+ def prime_feature(self, connection, device, **kwargs) -> None:
+ # language=rst
+ """
+ Prepares a feature after it has been placed in a connection. This takes care of learning rules, feature
+ value initialization, and asserting that features have proper shape. Should occur after primary constructor.
+ """
+
+ # Note: DO NOT move NoOp to global; cyclical dependency
+ from ..learning.MCC_learning import NoOp
+
+ # Check if feature is already primed
+ if self.is_primed:
+ return
+ self.is_primed = True
+
+ # Check if feature is a child feature
+ if self.parent_feature is not None:
+ self.link(self.parent_feature)
+ self.learning_rule = NoOp(connection=connection)
+ return
+
+ # Check if values/norms are the correct shape
+ if isinstance(self.value, torch.Tensor):
+ assert tuple(self.value.shape) == (connection.source.n, connection.target.n)
+
+ if self.norm is not None and isinstance(self.norm, torch.Tensor):
+ assert self.norm.shape[0] == connection.target.n
+
+ #### Initialize feature value ####
+ if self.value is None:
+ self.value = (
+ self.initialize_value()
+ ) # This should be defined per the sub-class
+
+ if isinstance(self.value, (int, float)):
+ self.value = torch.Tensor([self.value])
+
+ # Parameterize and send to proper device
+ # Note: Floating is used here to avoid dtype conflicts
+ self.value = Parameter(self.value, requires_grad=False).to(device)
+
+ ##### Initialize learning rule #####
+
+ # Default is NoOp
+ if self.learning_rule is None:
+ self.learning_rule = NoOp
+
+ self.learning_rule = self.learning_rule(
+ connection=connection,
+ feature_value=self.value,
+ range=self.range,
+ nu=self.nu,
+ reduction=self.reduction,
+ decay=self.decay,
+ **kwargs,
+ )
+
+ #### Recycle unnecessary variables ####
+ del self.nu, self.reduction, self.decay, self.range
+
+ def update(self, **kwargs) -> None:
+ # language=rst
+ """
+ Compute feature's update rule
+ """
+
+ self.learning_rule.update(**kwargs)
+
+ def normalize(self) -> None:
+ # language=rst
+ """
+ Normalize feature so each target neuron has sum of feature values equal to
+ ``self.norm``.
+ """
+
+ if self.norm is not None:
+ abs_sum = self.value.sum(0).unsqueeze(0)
+ abs_sum[abs_sum == 0] = 1.0
+ self.value *= self.norm / abs_sum
+
+ def degrade(self) -> None:
+ # language=rst
+ """
+        Degrade the value of the propagated spikes according to the feature's value. A lambda function should be passed
+        into the constructor which takes a single argument (which represents the value), and returns a value which will
+        be *subtracted* from the propagated spikes.
+        """
+
+        # Note: subclasses (e.g. ``Degradation``) store the lambda as ``self.degrade_function``.
+        return self.degrade_function(self.value)
+
+ def link(self, parent_feature) -> None:
+ # language=rst
+ """
+ Allow two features to share tensor values
+ """
+
+ valid_features = (Probability, Weight, Bias, Intensity)
+
+ assert isinstance(self, valid_features), f"A {self} cannot use feature linking"
+ assert isinstance(
+ parent_feature, valid_features
+ ), f"A {parent_feature} cannot use feature linking"
+ assert self.is_primed, f"Prime feature before linking: {self}"
+ assert (
+ parent_feature.is_primed
+ ), f"Prime parent feature before linking: {parent_feature}"
+
+ # Link values, disable learning for this feature
+ self.value = parent_feature.value
+ self.learning_rule = NoOp
+
+ def assert_valid_range(self):
+ # language=rst
+ """
+ Default range verifier (within [-1, +1])
+ """
+
+ r = self.range
+
+ ## Check dtype ##
+ assert isinstance(
+ self.range, (list, tuple)
+ ), f"Invalid range for feature {self.name}: range should be a list or tuple, not {type(self.range)}"
+ assert (
+ len(r) == 2
+ ), f"Invalid range for feature {self.name}: range should have a length of 2"
+
+ ## Check min/max relation ##
+ if isinstance(r[0], torch.Tensor) or isinstance(r[1], torch.Tensor):
+ assert (
+ r[0] < r[1]
+ ).all(), f"Invalid range for feature {self.name}: a min is larger than an adjacent max"
+ else:
+ assert (
+ r[0] < r[1]
+ ), f"Invalid range for feature {self.name}: the min value is larger than the max value"
+
+ def assert_feature_in_range(self):
+ r = self.range
+ f = self.value
+
+ if isinstance(r[0], torch.Tensor) or isinstance(f, torch.Tensor):
+ assert (
+ f >= r[0]
+ ).all(), f"Feature out of range for {self.name}: Features values not in [{r[0]}, {r[1]}]"
+ else:
+ assert (
+ f >= r[0]
+ ), f"Feature out of range for {self.name}: Features values not in [{r[0]}, {r[1]}]"
+
+ if isinstance(r[1], torch.Tensor) or isinstance(f, torch.Tensor):
+ assert (
+ f <= r[1]
+ ).all(), f"Feature out of range for {self.name}: Features values not in [{r[0]}, {r[1]}]"
+ else:
+ assert (
+ f <= r[1]
+ ), f"Feature out of range for {self.name}: Features values not in [{r[0]}, {r[1]}]"
+
+ def assert_valid_shape(self, source_shape, target_shape, f):
+ # Multidimensional feat
+ if len(f.shape) > 1:
+ assert f.shape == (
+ source_shape,
+ target_shape,
+ ), f"Feature {self.name} has an incorrect shape of {f.shape}. Should be of shape {(source_shape, target_shape)}"
+ # Else assume scalar, which is a valid shape
+
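+# Sketch of a custom feature (comment only; ``Halve`` is an illustrative name, not
+# part of the library): subclasses implement ``compute`` and
+# ``reset_state_variables``; value initialization, learning-rule wiring, and device
+# placement are handled by the inherited ``prime_feature``.
+#
+#     class Halve(AbstractFeature):
+#         def __init__(self, name: str) -> None:
+#             super().__init__(name=name, value=0.5)
+#
+#         def compute(self, conn_spikes):
+#             return conn_spikes * self.value
+#
+#         def reset_state_variables(self) -> None:
+#             pass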
+
+class Probability(AbstractFeature):
+ def __init__(
+ self,
+ name: str,
+ value: Union[torch.Tensor, float, int] = None,
+ range: Optional[Sequence[float]] = None,
+ norm: Optional[Union[torch.Tensor, float, int]] = None,
+ learning_rule: Optional[bindsnet.learning.LearningRule] = None,
+ nu: Optional[Union[list, tuple]] = None,
+ reduction: Optional[callable] = None,
+ decay: float = 0.0,
+ parent_feature=None,
+ ) -> None:
+ # language=rst
+ """
+ Will run a bernoulli trial using :code:`value` to determine if a signal will successfully traverse the synapse
+ :param name: Name of the feature
+ :param value: Number(s) in [0, 1] which represent the probability of a signal traversing a synapse. Tensor values
+ assume that probabilities will be matched to adjacent synapses in the connection. Scalars will be applied to
+ all synapses.
+ :param range: Range of acceptable values for the :code:`value` parameter. Should be in [0, 1]
+ :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
+ and after the value has been updated by the learning rule (if there is one)
+ :param learning_rule: Rule which will modify the :code:`value` after each sample
+ :param nu: Learning rate for the learning rule
+ :param reduction: Method for reducing parameter updates along the minibatch
+ dimension
+ :param decay: Constant multiple to decay weights by on each iteration
+ :param parent_feature: Parent feature to inherit :code:`value` from
+ """
+
+ ### Assertions ###
+ super().__init__(
+ name=name,
+ value=value,
+ range=[0, 1] if range is None else range,
+ norm=norm,
+ learning_rule=learning_rule,
+ nu=nu,
+ reduction=reduction,
+ decay=decay,
+ parent_feature=parent_feature,
+ )
+
+ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+ return conn_spikes * torch.bernoulli(self.value)
+
+ def reset_state_variables(self) -> None:
+ pass
+
+ def prime_feature(self, connection, device, **kwargs) -> None:
+ ## Initialize value ###
+ if self.value is None:
+ self.initialize_value = lambda: torch.clamp(
+ torch.rand(connection.source.n, connection.target.n, device=device),
+ self.range[0],
+ self.range[1],
+ )
+
+ super().prime_feature(connection, device, **kwargs)
+
+ def assert_valid_range(self):
+ super().assert_valid_range()
+
+ r = self.range
+
+ ## Check min greater than 0 ##
+ if isinstance(r[0], torch.Tensor):
+ assert (
+ r[0] >= 0
+ ).all(), (
+ f"Invalid range for feature {self.name}: a min value is less than 0"
+ )
+ elif isinstance(r[0], (float, int)):
+ assert (
+ r[0] >= 0
+ ), f"Invalid range for feature {self.name}: the min value is less than 0"
+ else:
+ assert (
+ False
+ ), f"Invalid range for feature {self.name}: the min value must be of type torch.Tensor, float, or int"
+
+
+class Mask(AbstractFeature):
+ def __init__(
+ self,
+ name: str,
+ value: Union[torch.Tensor, float, int] = None,
+ ) -> None:
+ # language=rst
+ """
+ Boolean mask which determines whether or not signals are allowed to traverse certain synapses.
+ :param name: Name of the feature
+ :param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable
+ """
+
+        ### Assertions ###
+        if isinstance(value, torch.Tensor):
+            assert (
+                value.dtype == torch.bool
+            ), "Mask must be of type bool, not {}".format(value.dtype)
+        elif value is not None:
+            assert isinstance(value, bool), "Mask must be of type bool, not {}".format(
+                type(value)
+            )
+
+            # Send boolean to tensor (priming won't work if it's not a tensor)
+            value = torch.tensor(value)
+
+ super().__init__(
+ name=name,
+ value=value,
+ )
+
+ self.name = name
+ self.value = value
+
+ def compute(self, conn_spikes) -> torch.Tensor:
+ return conn_spikes * self.value
+
+ def reset_state_variables(self) -> None:
+ pass
+
+ def prime_feature(self, connection, device, **kwargs) -> None:
+ # Check if feature is already primed
+ if self.is_primed:
+ return
+ self.is_primed = True
+
+ #### Initialize feature value ####
+ if self.value is None:
+ self.value = (
+ torch.rand(connection.source.n, connection.target.n) > 0.99
+ ).to(device=device)
+ self.value = Parameter(self.value, requires_grad=False).to(device)
+
+ #### Assertions ####
+ # Check if tensor values are the correct shape
+ if isinstance(self.value, torch.Tensor):
+ self.assert_valid_shape(
+ connection.source.n, connection.target.n, self.value
+ )
+
+ ##### Initialize learning rule #####
+ # Note: DO NOT move NoOp to global; cyclical dependency
+ from ..learning.MCC_learning import NoOp
+
+ # Default is NoOp
+ if self.learning_rule is None:
+ self.learning_rule = NoOp
+
+ self.learning_rule = self.learning_rule(
+ connection=connection,
+            feature_value=self.value,
+ range=self.range,
+ nu=self.nu,
+ reduction=self.reduction,
+ decay=self.decay,
+ **kwargs,
+ )
+
+
+class MeanField(AbstractFeature):
+    def __init__(self, name: str) -> None:
+        # language=rst
+        """
+        Takes the mean of all outgoing signals, and outputs that mean across every synapse in the connection
+        :param name: Name of the feature
+        """
+        # A constant placeholder value lets the inherited priming logic run without a per-synapse tensor.
+        super().__init__(name=name, value=1.0)
+
+ def reset_state_variables(self) -> None:
+ pass
+
+ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+ return conn_spikes.mean() * torch.ones(
+ self.source_n * self.target_n, device=self.device
+ )
+
+    def prime_feature(self, connection, device, **kwargs) -> None:
+        self.source_n = connection.source.n
+        self.target_n = connection.target.n
+        self.device = device
+
+        super().prime_feature(connection, device, **kwargs)
+
+
+class Weight(AbstractFeature):
+ def __init__(
+ self,
+ name: str,
+ value: Union[torch.Tensor, float, int] = None,
+ range: Optional[Sequence[float]] = None,
+ norm: Optional[Union[torch.Tensor, float, int]] = None,
+ norm_frequency: Optional[str] = "sample",
+ learning_rule: Optional[bindsnet.learning.LearningRule] = None,
+ nu: Optional[Union[list, tuple]] = None,
+ reduction: Optional[callable] = None,
+ enforce_polarity: Optional[bool] = False,
+ decay: float = 0.0,
+ ) -> None:
+ # language=rst
+ """
+ Multiplies signals by scalars
+ :param name: Name of the feature
+ :param value: Values to scale signals by
+ :param range: Range of acceptable values for the :code:`value` parameter
+ :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
+ and after the value has been updated by the learning rule (if there is one)
+ :param norm_frequency: How often to normalize weights:
+ * 'sample': weights normalized after each sample
+ * 'time step': weights normalized after each time step
+ :param learning_rule: Rule which will modify the :code:`value` after each sample
+ :param nu: Learning rate for the learning rule
+ :param reduction: Method for reducing parameter updates along the minibatch
+ dimension
+ :param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
+ :param decay: Constant multiple to decay weights by on each iteration
+ """
+
+ self.norm_frequency = norm_frequency
+ self.enforce_polarity = enforce_polarity
+ super().__init__(
+ name=name,
+ value=value,
+ range=[-torch.inf, +torch.inf] if range is None else range,
+ norm=norm,
+ learning_rule=learning_rule,
+ nu=nu,
+ reduction=reduction,
+ decay=decay,
+ )
+
+ def reset_state_variables(self) -> None:
+ pass
+
+ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+ if self.enforce_polarity:
+ pos_mask = ~torch.logical_xor(self.value > 0, self.positive_mask)
+ neg_mask = ~torch.logical_xor(self.value < 0, ~self.positive_mask)
+ self.value = self.value * torch.logical_or(pos_mask, neg_mask)
+ self.value[~pos_mask] = 0.0001
+ self.value[~neg_mask] = -0.0001
+
+ return_val = self.value * conn_spikes
+ if self.norm_frequency == "time step":
+ self.normalize(time_step_norm=True)
+
+ return return_val
+
+ def prime_feature(self, connection, device, **kwargs) -> None:
+ #### Initialize value ####
+ if self.value is None:
+ self.initialize_value = lambda: torch.rand(
+ connection.source.n, connection.target.n
+ )
+
+ super().prime_feature(
+ connection, device, enforce_polarity=self.enforce_polarity, **kwargs
+ )
+ if self.enforce_polarity:
+ self.positive_mask = ((self.value > 0).sum(1) / self.value.shape[1]) > 0.5
+ tmp = torch.zeros_like(self.value)
+ tmp[self.positive_mask, :] = 1
+ self.positive_mask = tmp.bool()
+
+ def normalize(self, time_step_norm=False) -> None:
+ # 'time_step_norm' will indicate if normalize is being called from compute()
+ # or from network.py (after a sample is completed)
+
+ if self.norm_frequency == "time step" and time_step_norm:
+ super().normalize()
+
+ if self.norm_frequency == "sample" and not time_step_norm:
+ super().normalize()
+
+
+class Bias(AbstractFeature):
+ def __init__(
+ self,
+ name: str,
+ value: Union[torch.Tensor, float, int] = None,
+ range: Optional[Sequence[float]] = None,
+ norm: Optional[Union[torch.Tensor, float, int]] = None,
+ ) -> None:
+ # language=rst
+ """
+ Adds scalars to signals
+ :param name: Name of the feature
+ :param value: Values to add to the signals
+ :param range: Range of acceptable values for the :code:`value` parameter
+ :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
+ and after the value has been updated by the learning rule (if there is one)
+ """
+
+ super().__init__(
+ name=name,
+ value=value,
+ range=[-torch.inf, +torch.inf] if range is None else range,
+ norm=norm,
+ )
+
+ def reset_state_variables(self) -> None:
+ pass
+
+ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+ return conn_spikes + self.value
+
+ def prime_feature(self, connection, device, **kwargs) -> None:
+ #### Initialize value ####
+ if self.value is None:
+ self.initialize_value = lambda: torch.rand(
+ connection.source.n, connection.target.n
+ )
+
+ super().prime_feature(connection, device, **kwargs)
+
+
+class Intensity(AbstractFeature):
+ def __init__(
+ self,
+ name: str,
+ value: Union[torch.Tensor, float, int] = None,
+ range: Optional[Sequence[float]] = None,
+ ) -> None:
+ # language=rst
+ """
+        Scales signals by constant values
+ :param name: Name of the feature
+ :param value: Values to scale signals by
+ """
+
+ super().__init__(name=name, value=value, range=range)
+
+ def reset_state_variables(self) -> None:
+ pass
+
+ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+ return conn_spikes * self.value
+
+ def prime_feature(self, connection, device, **kwargs) -> None:
+ #### Initialize value ####
+ if self.value is None:
+ self.initialize_value = lambda: torch.clamp(
+ torch.sign(
+ torch.randint(-1, +2, (connection.source.n, connection.target.n))
+ ),
+ self.range[0],
+ self.range[1],
+ )
+
+ super().prime_feature(connection, device, **kwargs)
+
+
+class Degradation(AbstractFeature):
+ def __init__(
+ self,
+ name: str,
+ value: Union[torch.Tensor, float, int] = None,
+ degrade_function: callable = None,
+ parent_feature: Optional[AbstractFeature] = None,
+ ) -> None:
+ # language=rst
+ """
+ Degrades propagating spikes according to :code:`degrade_function`.
+ Note: If :code:`parent_feature` is provided, it will override :code:`value`.
+ :param name: Name of the feature
+ :param value: Value used to degrade feature
+ :param degrade_function: Callable function which takes a single argument (:code:`value`) and returns a tensor or
+ constant to be *subtracted* from the propagating spikes.
+ :param parent_feature: Parent feature with desired :code:`value` to inherit
+ """
+
+ # Note: parent_feature will override value. See abstract constructor
+ super().__init__(name=name, value=value, parent_feature=parent_feature)
+
+ self.degrade_function = degrade_function
+
+ def reset_state_variables(self) -> None:
+ pass
+
+ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+ return conn_spikes - self.degrade_function(self.value)
+
+
+class AdaptationBaseSynapsHistory(AbstractFeature):
+ def __init__(
+ self,
+ name: str,
+ value: Union[torch.Tensor, float, int] = None,
+ ann_values: Union[list, tuple] = None,
+ const_update_rate: float = 0.1,
+ const_decay: float = 0.001,
+ ) -> None:
+ # language=rst
+ """
+        An ANN is applied to each synapse to measure the neuron's previous activity and decide whether to close or open the connection.
+
+        :param name: Name of the feature
+        :param ann_values: Values used to build the ANN that adapts the connectivity of the layer.
+        :param value: Values used to build an initial mask for the synapses.
+        :param const_update_rate: Update rate applied to the mask from the ANN's decision.
+        :param const_decay: Rate of spontaneous activation of the synapses.
+ """
+
+ # Define the ANN
+ class ANN(nn.Module):
+ def __init__(self, input_size, hidden_size, output_size):
+ super(ANN, self).__init__()
+ self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
+ self.fc2 = nn.Linear(hidden_size, output_size, bias=False)
+
+ def forward(self, x):
+ x = torch.relu(self.fc1(x))
+ x = torch.tanh(self.fc2(x)) # MUST HAVE output between -1 and 1
+ return x
+
+ self.init_value = value.clone().detach() # initial mask
+ self.mask = value # final decision of the ANN
+ value = torch.zeros_like(value) # initial mask
+ self.ann = ANN(ann_values[0].shape[0], ann_values[0].shape[1], 1)
+
+ # load weights from ann_values
+ with torch.no_grad():
+ self.ann.fc1.weight.data = ann_values[0]
+ self.ann.fc2.weight.data = ann_values[1]
+ self.ann.to(ann_values[0].device)
+
+ self.spike_buffer = torch.zeros(
+ (value.numel(), ann_values[0].shape[1]),
+ device=ann_values[0].device,
+ dtype=torch.bool,
+ )
+ self.counter = 0
+ self.start_counter = False
+ self.const_update_rate = const_update_rate
+ self.const_decay = const_decay
+
+ super().__init__(name=name, value=value)
+
+ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+
+ # Update the spike buffer
+ if self.start_counter == False or conn_spikes.sum() > 0:
+ self.start_counter = True
+ self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
+ conn_spikes.flatten()
+ )
+ self.counter += 1
+
+ # Update the masks
+ if self.counter % self.spike_buffer.shape[1] == 0:
+ with torch.no_grad():
+ ann_decision = self.ann(self.spike_buffer.to(torch.float32))
+ self.mask += (
+ ann_decision.view(self.mask.shape) * self.const_update_rate
+ ) # update mask with learning rate fraction
+ self.mask += self.const_decay # spontaneous activate synapses
+ self.mask = torch.clamp(self.mask, -1, 1) # cap the mask
+
+ self.value = (self.mask > 0).float()
+
+ return conn_spikes * self.value
+
+    def reset_state_variables(self) -> None:
+        self.spike_buffer = torch.zeros_like(self.spike_buffer)
+        self.counter = 0
+        self.start_counter = False
+        self.value = self.init_value.clone().detach()  # restore initial mask
+
+
+class AdaptationBaseOtherSynaps(AbstractFeature):
+ def __init__(
+ self,
+ name: str,
+ value: Union[torch.Tensor, float, int] = None,
+ ann_values: Union[list, tuple] = None,
+ const_update_rate: float = 0.1,
+ const_decay: float = 0.01,
+ ) -> None:
+ # language=rst
+ """
+        An ANN is applied to each synapse's recent spike history to decide whether to open or close the connection.
+
+        :param name: Name of the feature
+        :param ann_values: Weights used to build the ANN that adapts the connectivity of the layer.
+        :param value: Initial mask values for the synapses.
+        :param const_update_rate: Rate at which the ANN decision updates the mask.
+        :param const_decay: Rate of spontaneous reactivation of the synapses.
+ """
+
+ # Define the ANN
+ class ANN(nn.Module):
+ def __init__(self, input_size, hidden_size, output_size):
+ super(ANN, self).__init__()
+ self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
+ self.fc2 = nn.Linear(hidden_size, output_size, bias=False)
+
+ def forward(self, x):
+ x = torch.relu(self.fc1(x))
+ x = torch.tanh(self.fc2(x)) # MUST HAVE output between -1 and 1
+ return x
+
+        self.init_value = value.clone().detach()  # initial mask, restored on reset
+        self.mask = value  # running mask updated by the ANN decisions (clamped to [-1, 1])
+        value = torch.zeros_like(value)  # feature value; recomputed from the mask on each compute call
+ self.ann = ANN(ann_values[0].shape[0], ann_values[0].shape[1], 1)
+
+ # load weights from ann_values
+ with torch.no_grad():
+ self.ann.fc1.weight.data = ann_values[0]
+ self.ann.fc2.weight.data = ann_values[1]
+ self.ann.to(ann_values[0].device)
+
+ self.spike_buffer = torch.zeros(
+ (value.numel(), ann_values[0].shape[1]),
+ device=ann_values[0].device,
+ dtype=torch.bool,
+ )
+ self.counter = 0
+ self.start_counter = False
+ self.const_update_rate = const_update_rate
+ self.const_decay = const_decay
+
+ super().__init__(name=name, value=value)
+
+ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+
+ # Update the spike buffer
+        if not self.start_counter or conn_spikes.sum() > 0:
+ self.start_counter = True
+ self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
+ conn_spikes.flatten()
+ )
+ self.counter += 1
+
+ # Update the masks
+ if self.counter % self.spike_buffer.shape[1] == 0:
+ with torch.no_grad():
+ ann_decision = self.ann(self.spike_buffer.to(torch.float32))
+ self.mask += (
+ ann_decision.view(self.mask.shape) * self.const_update_rate
+ ) # update mask with learning rate fraction
+ self.mask += self.const_decay # spontaneous activate synapses
+ self.mask = torch.clamp(self.mask, -1, 1) # cap the mask
+
+ self.value = (self.mask > 0).float()
+
+ return conn_spikes * self.value
+
+    def reset_state_variables(self) -> None:
+        self.spike_buffer = torch.zeros_like(self.spike_buffer)
+        self.counter = 0
+        self.start_counter = False
+        self.value = self.init_value.clone().detach()  # restore initial mask
+
+
+### Sub Features ###
+
+
+class AbstractSubFeature(ABC):
+ # language=rst
+ """
+    A way to inject a feature's methods (e.g., normalization, learning) into the pipeline for user-controlled
+    execution.
+ """
+
+ @abstractmethod
+ def __init__(
+ self,
+ name: str,
+ parent_feature: AbstractFeature,
+ ) -> None:
+ # language=rst
+ """
+        Instantiates an :code:`AbstractSubFeature` object. Assigns all incoming arguments as class variables.
+        :param name: Name of the sub-feature
+        :param parent_feature: Primary feature which the sub-feature will modify
+ """
+
+ self.name = name
+ self.parent = parent_feature
+ self.sub_feature = None # <-- Defined in non-abstract constructor
+
+ def compute(self, _) -> None:
+ # language=rst
+ """
+        Proxy function invoked when the pipeline in topology.py's :code:`compute` function reaches this object. Allows
+        :code:`SubFeature` objects to be executed like real features in the pipeline.
+ """
+
+ # sub_feature should be defined in the non-abstract constructor
+ self.sub_feature()
+
+
+class Normalization(AbstractSubFeature):
+ # language=rst
+ """
+    Normalize the parent feature's values so that the values incident to each target neuron sum to a desired value :code:`norm`.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ parent_feature: AbstractFeature,
+ ) -> None:
+ super().__init__(name, parent_feature)
+
+ self.sub_feature = self.parent.normalize
+
+
+class Updating(AbstractSubFeature):
+ # language=rst
+ """
+    Update the parent feature's values using its assigned learning rule.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ parent_feature: AbstractFeature,
+ ) -> None:
+ super().__init__(name, parent_feature)
+
+ self.sub_feature = self.parent.update
diff --git a/docs/source/guide/guide_part_i.rst b/docs/source/guide/guide_part_i.rst
index ff425e8b..d8bc35d7 100644
--- a/docs/source/guide/guide_part_i.rst
+++ b/docs/source/guide/guide_part_i.rst
@@ -145,6 +145,33 @@ weights based on pre-, post-synaptic activity and possibly other signals; e.g.,
:code:`normalize` (for ensuring weights incident to post-synaptic neurons sum to a pre-specified value), and :code:`reset_state_variables`
(for re-initializing stateful variables for the start of a new simulation).
+For more complex connections, the :code:`MulticompartmentConnection` class can be used. A :code:`MulticompartmentConnection` passes spikes through a sequence of "features",
+such as weights, biases, and boolean masks. Features are passed to the :code:`MulticompartmentConnection` constructor as a list (the *pipeline*) and are executed in order.
+For example, the code below uses a pipeline containing a weight feature followed by a bias feature. At runtime, spikes from the source layer are first multiplied by the weights,
+and then the bias is added. Additional features can be placed before, after, or between these two.
+To create a simple all-to-all connection with a weight and bias:
+
+.. code-block:: python
+
+    import torch
+
+    from bindsnet.network.nodes import Input, LIFNodes
+ from bindsnet.network.topology import MulticompartmentConnection
+ from bindsnet.network.topology_features import Weight, Bias
+
+ # Create two populations of neurons, one to act as the "source"
+ # population, and the other, the "target population".
+ source_layer = Input(n=100)
+ target_layer = LIFNodes(n=1000)
+
+ # Create 'pipeline' of features that spikes will pass through
+ weights = Weight(name='weight_feature', value=torch.rand(100, 1000))
+ bias = Bias(name='bias_feature', value=torch.rand(100, 1000))
+
+ # Connect the two layers.
+ connection = MulticompartmentConnection(
+ source=source_layer, target=target_layer,
+        pipeline=[weights, bias]
+ )
+
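+As with other connection types, the resulting connection is registered with a :code:`Network` object before simulation.
+A minimal sketch (layer names are arbitrary):
+
+.. code-block:: python
+
+    from bindsnet.network import Network
+
+    network = Network()
+    network.add_layer(source_layer, name='X')
+    network.add_layer(target_layer, name='Y')
+    network.add_connection(connection, source='X', target='Y')
+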
Specifying monitors
*******************
@@ -176,7 +203,9 @@ course of simulation in certain network components. To create a monitor to monit
The user must specify a :code:`Nodes` or :code:`AbstractConnection` object from which to record, attributes of that
object to record (:code:`state_vars`), and, optionally, how many time steps the simulation(s) will last, in order to
-save time by pre-allocating memory.
+save time by pre-allocating memory.
+
+Note that monitors are not officially supported for the :code:`MulticompartmentConnection` class.
To add a monitor to the network (thereby enabling monitoring), use the :code:`add_monitor` function of the
:py:class:`bindsnet.network.Network` class:
diff --git a/docs/source/guide/guide_part_ii.rst b/docs/source/guide/guide_part_ii.rst
index 13ef5b31..786753a3 100644
--- a/docs/source/guide/guide_part_ii.rst
+++ b/docs/source/guide/guide_part_ii.rst
@@ -74,4 +74,33 @@ Custom learning rules can be implemented by subclassing :code:`bindsnet.learning
and providing implementations for the types of :code:`AbstractConnection` objects intended to be used.
For example, the :code:`Connection` and :code:`LocalConnection` objects rely on the implementation
of a private method, :code:`_connection_update`, whereas the :code:`Conv2dConnection` object
-uses the :code:`_conv2d_connection_update` version.
\ No newline at end of file
+uses the :code:`_conv2d_connection_update` version.
+
+If using a :code:`MulticompartmentConnection`, you can attach a learning rule to a specific feature. Note that only
+:code:`NoOp`, :code:`PostPre`, :code:`MSTDP`, and :code:`MSTDPET` are supported; they are located in
+:code:`bindsnet.learning.MCC_learning`. Below is an example of applying a :code:`PostPre` learning rule to a weight feature.
+Note that the bias does not have a learning rule, so it remains static.
+
+.. code-block:: python
+
+    import torch
+
+    from bindsnet.network.nodes import Input, LIFNodes
+    from bindsnet.network.topology import MulticompartmentConnection
+    from bindsnet.network.topology_features import Weight, Bias
+    from bindsnet.learning.MCC_learning import PostPre
+
+ # Create two populations of neurons, one to act as the "source"
+ # population, and the other, the "target population".
+ # Neurons involved in certain learning rules must record synaptic
+ # traces, a vector of short-term memories of the last emitted spikes.
+ source_layer = Input(n=100, traces=True)
+ target_layer = LIFNodes(n=1000, traces=True)
+
+ # Create 'pipeline' of features that spikes will pass through
+ weights = Weight(name='weight_feature', value=torch.rand(100, 1000),
+ learning_rule=PostPre, nu=(1e-4, 1e-2))
+ bias = Bias(name='bias_feature', value=torch.rand(100, 1000))
+
+ # Connect the two layers.
+    connection = MulticompartmentConnection(
+        source=source_layer, target=target_layer,
+        pipeline=[weights, bias]
+    )
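+
+For reward-modulated rules (:code:`MSTDP`, :code:`MSTDPET`), a reward signal is also passed when running the network;
+a minimal sketch (assuming a :code:`Network` object named :code:`network` and encoded inputs :code:`input_spikes`):
+
+.. code-block:: python
+
+    network.run(inputs={'X': input_spikes}, time=250, reward=1.0)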
\ No newline at end of file
diff --git a/examples/mnist/MCC_reservoir.py b/examples/mnist/MCC_reservoir.py
new file mode 100644
index 00000000..89145d6f
--- /dev/null
+++ b/examples/mnist/MCC_reservoir.py
@@ -0,0 +1,309 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import argparse
+import matplotlib.pyplot as plt
+
+from torchvision import transforms
+from tqdm import tqdm
+
+from bindsnet.analysis.plotting import (
+ plot_input,
+ plot_spikes,
+ plot_voltages,
+ plot_weights,
+)
+from bindsnet.datasets import MNIST
+from bindsnet.encoding import PoissonEncoder
+from bindsnet.network import Network
+from bindsnet.network.nodes import Input
+from bindsnet.network.topology_features import Probability, Weight, Mask
+
+# Build a simple two-layer, input-output network.
+from bindsnet.network.monitors import Monitor
+from bindsnet.network.nodes import LIFNodes
+from bindsnet.network.topology import MulticompartmentConnection
+from bindsnet.utils import get_square_weights
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--seed", type=int, default=0)
+parser.add_argument("--n_neurons", type=int, default=500)
+parser.add_argument("--n_epochs", type=int, default=100)
+parser.add_argument("--examples", type=int, default=500)
+parser.add_argument("--n_workers", type=int, default=-1)
+parser.add_argument("--time", type=int, default=250)
+parser.add_argument("--dt", type=int, default=1.0)
+parser.add_argument("--intensity", type=float, default=64)
+parser.add_argument("--progress_interval", type=int, default=10)
+parser.add_argument("--update_interval", type=int, default=250)
+parser.add_argument("--plot", dest="plot", action="store_true")
+parser.add_argument("--gpu", dest="gpu", action="store_true")
+parser.set_defaults(plot=False, gpu=True, train=True)
+
+args = parser.parse_args()
+
+seed = args.seed
+n_neurons = args.n_neurons
+n_epochs = args.n_epochs
+examples = args.examples
+n_workers = args.n_workers
+time = args.time
+dt = args.dt
+intensity = args.intensity
+progress_interval = args.progress_interval
+update_interval = args.update_interval
+train = args.train
+plot = args.plot
+gpu = args.gpu
+
+np.random.seed(seed)
+torch.cuda.manual_seed_all(seed)
+torch.manual_seed(seed)
+
+# Sets up Gpu use
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if gpu and torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+else:
+ torch.manual_seed(seed)
+ device = "cpu"
+ if gpu:
+ gpu = False
+torch.set_num_threads(os.cpu_count() - 1)
+print("Running on Device = ", device)
+
+
+### Base model ###
+model = Network()
+model.to(device)
+
+
+### Layers ###
+input_l = Input(n=784, shape=(1, 28, 28), traces=True)
+output_l = LIFNodes(
+ n=n_neurons, thresh=-52 + np.random.randn(n_neurons).astype(float), traces=True
+)
+
+model.add_layer(input_l, name="X")
+model.add_layer(output_l, name="Y")
+
+
+### Connections ###
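+# Each MulticompartmentConnection passes spikes through its feature pipeline in order:
+# here, a Probability feature (per-synapse value p) followed by a Weight feature
+# (random ternary weights in {-1, 0, +1}).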
+p = torch.rand(input_l.n, output_l.n)
+d = torch.rand(input_l.n, output_l.n) / 5
+w = torch.sign(torch.randint(-1, +2, (input_l.n, output_l.n)))
+prob_feature = Probability(name="input_prob_feature", value=p)
+weight_feature = Weight(name="input_weight_feature", value=w)
+pipeline = [prob_feature, weight_feature]
+input_con = MulticompartmentConnection(
+ source=input_l,
+ target=output_l,
+ device=device,
+ pipeline=pipeline,
+)
+
+p = torch.rand(output_l.n, output_l.n)
+d = torch.rand(output_l.n, output_l.n) / 5
+w = torch.sign(torch.randint(-1, +2, (output_l.n, output_l.n)))
+prob_feature = Probability(name="recc_prob_feature", value=p)
+weight_feature = Weight(name="recc_weight_feature", value=w)
+pipeline = [prob_feature, weight_feature]
+recurrent_con = MulticompartmentConnection(
+ source=output_l,
+ target=output_l,
+ device=device,
+ pipeline=pipeline,
+)
+
+model.add_connection(input_con, source="X", target="Y")
+model.add_connection(recurrent_con, source="Y", target="Y")
+
+# Directs network to GPU
+if gpu:
+ model.to("cuda")
+
+### MNIST ###
+dataset = MNIST(
+ PoissonEncoder(time=time, dt=dt),
+ None,
+ root=os.path.join("../../test", "..", "data", "MNIST"),
+ download=True,
+ transform=transforms.Compose(
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
+ ),
+)
+
+
+### Monitor setup ###
+inpt_axes = None
+inpt_ims = None
+spike_axes = None
+spike_ims = None
+weights_im = None
+weights_im2 = None
+voltage_ims = None
+voltage_axes = None
+spikes = {}
+voltages = {}
+for l in model.layers:
+ spikes[l] = Monitor(model.layers[l], ["s"], time=time, device=device)
+ model.add_monitor(spikes[l], name="%s_spikes" % l)
+
+voltages = {"Y": Monitor(model.layers["Y"], ["v"], time=time, device=device)}
+model.add_monitor(voltages["Y"], name="Y_voltages")
+
+
+### Running model on MNIST ###
+
+# Create a dataloader to iterate and batch data
+dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True
+)
+
+n_iters = examples
+training_pairs = []
+pbar = tqdm(enumerate(dataloader))
+for i, dataPoint in pbar:
+ if i > n_iters:
+ break
+
+ # Extract & resize the MNIST samples image data for training
+ # int(time / dt) -> length of spike train
+ # 28 x 28 -> size of sample
+ datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
+ label = dataPoint["label"]
+ pbar.set_description_str("Train progress: (%d / %d)" % (i, n_iters))
+
+ # Run network on sample image
+ model.run(inputs={"X": datum}, time=time, input_time_dim=1, reward=1.0)
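+    # Store the reservoir's per-neuron output spike counts (summed over time) with the
+    # label, to train the read-out classifier below.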
+ training_pairs.append([spikes["Y"].get("s").sum(0), label])
+
+ # Plot spiking activity using monitors
+ if plot:
+ inpt_axes, inpt_ims = plot_input(
+ dataPoint["image"].view(28, 28),
+ datum.view(int(time / dt), 784).sum(0).view(28, 28),
+ label=label,
+ axes=inpt_axes,
+ ims=inpt_ims,
+ )
+ spike_ims, spike_axes = plot_spikes(
+ {layer: spikes[layer].get("s").view(time, -1) for layer in spikes},
+ axes=spike_axes,
+ ims=spike_ims,
+ )
+ voltage_ims, voltage_axes = plot_voltages(
+ {layer: voltages[layer].get("v").view(time, -1) for layer in voltages},
+ ims=voltage_ims,
+ axes=voltage_axes,
+ )
+
+ plt.pause(1e-8)
+ model.reset_state_variables()
+
+
+### Classification ###
+# Define logistic regression model using PyTorch.
+# These neurons take the reservoir's output as their input and are trained to classify the images.
+class NN(nn.Module):
+ def __init__(self, input_size, num_classes):
+ super(NN, self).__init__()
+ # h = int(input_size/2)
+ self.linear_1 = nn.Linear(input_size, num_classes)
+ # self.linear_1 = nn.Linear(input_size, h)
+ # self.linear_2 = nn.Linear(h, num_classes)
+
+ def forward(self, x):
+ out = torch.sigmoid(self.linear_1(x.float().view(-1)))
+ # out = torch.sigmoid(self.linear_2(out))
+ return out
+
+
+# Create and train logistic regression model on reservoir outputs.
+learning_model = NN(n_neurons, 10).to(device)
+criterion = torch.nn.MSELoss(reduction="sum")
+optimizer = torch.optim.SGD(learning_model.parameters(), lr=1e-4, momentum=0.9)
+
+# Training the Model
+print("\n Training the read out")
+pbar = tqdm(enumerate(range(n_epochs)))
+for epoch, _ in pbar:
+ avg_loss = 0
+
+ # Extract spike outputs from reservoir for a training sample
+ # i -> Loop index
+ # s -> Reservoir output spikes
+ # l -> Image label
+ for i, (s, l) in enumerate(training_pairs):
+ # Reset gradients to 0
+ optimizer.zero_grad()
+
+ # Run spikes through logistic regression model
+ outputs = learning_model(s)
+
+ # Calculate MSE
+ label = torch.zeros(1, 1, 10).float().to(device)
+ label[0, 0, l] = 1.0
+ loss = criterion(outputs.view(1, 1, -1), label)
+ avg_loss += loss.data
+
+ # Optimize parameters
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description_str(
+ "Epoch: %d/%d, Loss: %.4f"
+ % (epoch + 1, n_epochs, avg_loss / len(training_pairs))
+ )
+
+# Run same simulation on reservoir with testing data instead of training data
+# (see training section for intuition)
+n_iters = examples
+test_pairs = []
+pbar = tqdm(enumerate(dataloader))
+for i, dataPoint in pbar:
+ if i > n_iters:
+ break
+ datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
+ label = dataPoint["label"]
+ pbar.set_description_str("Testing progress: (%d / %d)" % (i, n_iters))
+
+ model.run(inputs={"X": datum}, time=time, input_time_dim=1)
+ test_pairs.append([spikes["Y"].get("s").sum(0), label])
+
+ if plot:
+ inpt_axes, inpt_ims = plot_input(
+ dataPoint["image"].view(28, 28),
+ datum.view(time, 784).sum(0).view(28, 28),
+ label=label,
+ axes=inpt_axes,
+ ims=inpt_ims,
+ )
+ spike_ims, spike_axes = plot_spikes(
+ {layer: spikes[layer].get("s").view(time, -1) for layer in spikes},
+ axes=spike_axes,
+ ims=spike_ims,
+ )
+ voltage_ims, voltage_axes = plot_voltages(
+ {layer: voltages[layer].get("v").view(time, -1) for layer in voltages},
+ ims=voltage_ims,
+ axes=voltage_axes,
+ )
+
+ plt.pause(1e-8)
+ model.reset_state_variables()
+
+# Test learning model with previously trained logistic regression classifier
+correct, total = 0, 0
+for s, label in test_pairs:
+ outputs = learning_model(s)
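+    # Predicted class = index of the largest read-out output across the 10 class units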
+ _, predicted = torch.max(outputs.data.unsqueeze(0), 1)
+ total += 1
+ correct += int(predicted == label.long().to(device))
+
+print(
+ "\n Accuracy of the model on %d test images: %.2f %%"
+ % (n_iters, 100 * correct / total)
+)
diff --git a/examples/mnist/reservoir.py b/examples/mnist/reservoir.py
index e61b17cc..9946fc32 100644
--- a/examples/mnist/reservoir.py
+++ b/examples/mnist/reservoir.py
@@ -101,7 +101,7 @@
dataset = MNIST(
PoissonEncoder(time=time, dt=dt),
None,
- root=os.path.join("..", "data", "MNIST"),
+ root=os.path.join("..", "..", "data", "MNIST"),
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]