diff --git a/README.md b/README.md index f160b3ce..b6ce35f4 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,8 @@ A Python package used for simulating spiking neural networks (SNNs) on CPUs or G BindsNET is a spiking neural network simulation library geared towards the development of biologically inspired algorithms for machine learning. -This package is used as part of ongoing research on applying SNNs to machine learning (ML) and reinforcement learning (RL) problems in the [Biologically Inspired Neural & Dynamical Systems (BINDS) lab](http://binds.cs.umass.edu/). +This package is used as part of ongoing research on applying SNNs, machine learning (ML) and reinforcement learning (RL) problems in the [Biologically Inspired Neural & Dynamical Systems (BINDS) lab](http://binds.cs.umass.edu/) and the Allen Discovery Center at Tufts University. + Check out the [BindsNET examples](https://github.com/BindsNET/bindsnet/tree/master/examples) for a collection of experiments, functions for the analysis of results, plots of experiment outcomes, and more. Documentation for the package can be found [here](https://bindsnet-docs.readthedocs.io). @@ -129,11 +130,22 @@ If you use BindsNET in your research, please cite the following [article](https: ## Contributors -- Daniel Saunders ([email](mailto:djsaunde@cs.umass.edu)) -- Hananel Hazan ([email](mailto:hananel@hazan.org.il)) -- Darpan Sanghavi ([email](mailto:dsanghavi@cs.umass.edu)) -- Hassaan Khan ([email](mailto:hqkhan@umass.edu)) -- Devdhar Patel ([email](mailto:devdharpatel@cs.umass.edu)) +- Hava Siegelmann - Director of BINDS lab at UMass +- Robert Kozma - Co-Director of BINDS lab (2018-2019) +- Hananel Hazan ([email](mailto:hananel@hazan.org.il)) - Spearheaded BindsNET and its main maintainer. +- Daniel Saunders ([email](mailto:djsaunde@cs.umass.edu)) - MSc student, BindsNET core functions coder (2018-2019). +- Darpan Sanghavi ([email](mailto:dsanghavi@cs.umass.edu)) - MSc student BINDS Lab (2018) +- Hassaan Khan ([email](mailto:hqkhan@umass.edu)) - MSc student BINDS Lab (2018) +- Devdhar Patel ([email](mailto:devdharpatel@cs.umass.edu)) - MSc student BINDS Lab (2018) +- Simon Caby [github](https://github.com/SimonInParis) +- Christopher Earl ([email](mailto:cearl@umass.edu)) - MSc student (2021 - present) + + + + + + +Made with [contrib.rocks](https://contrib.rocks). ## License GNU Affero General Public License v3.0 diff --git a/bindsnet/analysis/plotting.py b/bindsnet/analysis/plotting.py index bb5a39c1..2c2ebc04 100644 --- a/bindsnet/analysis/plotting.py +++ b/bindsnet/analysis/plotting.py @@ -703,7 +703,7 @@ def plot_voltages( .numpy()[ time[0] : time[1], n_neurons[v[0]][0] : n_neurons[v[0]][1], - ] + ], ) ) diff --git a/bindsnet/learning/MCC_learning.py b/bindsnet/learning/MCC_learning.py new file mode 100644 index 00000000..14565a80 --- /dev/null +++ b/bindsnet/learning/MCC_learning.py @@ -0,0 +1,721 @@ +from abc import ABC, abstractmethod +from typing import Union, Optional, Sequence +import warnings + +import torch +import numpy as np + +from ..network.nodes import SRM0Nodes +from ..network.topology import ( + AbstractMulticompartmentConnection, + MulticompartmentConnection, +) +from ..utils import im2col_indices + + +class MCC_LearningRule(ABC): + # language=rst + """ + Abstract base class for learning rules. + """ + + def __init__( + self, + connection: AbstractMulticompartmentConnection, + feature_value: Union[float, int, torch.Tensor], + range: Optional[Union[list, tuple]] = None, + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + decay: float = 0.0, + enforce_polarity: bool = False, + **kwargs, + ) -> None: + # language=rst + """ + Abstract constructor for the ``LearningRule`` object. + + :param connection: An ``AbstractConnection`` object. + :param feature_value: Value(s) to be updated. Can be only tensor (scalar currently not supported) + :param range: Allowed range for :code:`feature_value` + :param nu: Single or pair of learning rates for pre- and post-synaptic events. + :param reduction: Method for reducing parameter updates along the batch + dimension. + :param decay: Coefficient controlling rate of decay of the weights each iteration. + :param enforce_polarity: Will prevent synapses from changing signs if :code:`True` + """ + # Connection parameters. + self.connection = connection + self.source = connection.source + self.target = connection.target + self.feature_value = feature_value + self.enforce_polarity = enforce_polarity + self.min, self.max = range + + # Learning rate(s). + if nu is None: + nu = [0.2, 0.1] + elif isinstance(nu, (float, int)): + nu = [nu, nu] + + # Keep track of polarities + if enforce_polarity: + self.polarities = torch.sign(self.feature_value) + + self.nu = torch.zeros(2, dtype=torch.float) + self.nu[0] = nu[0] + self.nu[1] = nu[1] + + if (self.nu == torch.zeros(2)).all() and not isinstance(self, NoOp): + warnings.warn( + f"nu is set to [0., 0.] for {type(self).__name__} learning rule. " + + "It will disable the learning process." + ) + + # Parameter update reduction across minibatch dimension. + if reduction is None: + if self.source.batch_size == 1: + self.reduction = torch.squeeze + else: + self.reduction = torch.sum + else: + self.reduction = reduction + + # Weight decay. + self.decay = 1.0 - decay if decay else 1.0 + + def update(self, **kwargs) -> None: + # language=rst + """ + Abstract method for a learning rule update. + """ + + # Implement decay. + if self.decay: + self.feature_value *= self.decay + + # Enforce polarities + if self.enforce_polarity: + polarity_swaps = self.polarities == torch.sign(self.feature_value) + self.feature_value[polarity_swaps == 0] = 0 + + # Bound weights. + if ((self.min is not None) or (self.max is not None)) and not isinstance( + self, NoOp + ): + self.feature_value.clamp_(self.min, self.max) + + @abstractmethod + def reset_state_variables(self) -> None: + # language=rst + """ + Contains resetting logic for the feature. + """ + pass + + +class NoOp(MCC_LearningRule): + # language=rst + """ + Learning rule with no effect. + """ + + def __init__(self, **args) -> None: + # language=rst + """ + No operation done during runtime + """ + pass + + def update(self, **kwargs) -> None: + # language=rst + """ + No operation done during runtime + """ + pass + + def reset_state_variables(self) -> None: + # language=rst + """ + Contains resetting logic for the feature. + """ + pass + + +class PostPre(MCC_LearningRule): + # language=rst + """ + Simple STDP rule involving both pre- and post-synaptic spiking activity. By default, + pre-synaptic update is negative and the post-synaptic update is positive. + """ + + def __init__( + self, + connection: AbstractMulticompartmentConnection, + feature_value: Union[torch.Tensor, float, int], + range: Optional[Sequence[float]] = None, + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + decay: float = 0.0, + enforce_polarity: bool = False, + **kwargs, + ) -> None: + # language=rst + """ + Constructor for ``PostPre`` learning rule. + + :param connection: An ``AbstractConnection`` object whose weights the + ``PostPre`` learning rule will modify. + :param feature_value: The object which will be altered + :param range: The domain for the feature + :param nu: Single or pair of learning rates for pre- and post-synaptic events. + :param reduction: Method for reducing parameter updates along the batch + dimension. + :param decay: Coefficient controlling rate of decay of the weights each iteration. + :param enforce_polarity: Will prevent synapses from changing signs if :code:`True` + + Keyword arguments: + :param average_update: Number of updates to average over, 0=No averaging, x=average over last x updates + :param continues_update: If True, the update will be applied after every update, if False, only after the average_update buffer is full + """ + super().__init__( + connection=connection, + feature_value=feature_value, + range=[-1, +1] if range is None else range, + nu=nu, + reduction=reduction, + decay=decay, + enforce_polarity=enforce_polarity, + **kwargs, + ) + + assert self.source.traces and self.target.traces, ( + "Both pre- and post-synaptic nodes must record spike traces " + "(use traces='True' on source/target layers)" + ) + + if isinstance(connection, (MulticompartmentConnection)): + self.update = self._connection_update + # elif isinstance(connection, Conv2dConnection): + # self.update = self._conv2d_connection_update + else: + raise NotImplementedError( + "This learning rule is not supported for this Connection type." + ) + + # Initialize variables for average update and continues update + self.average_update = kwargs.get("average_update", 0) + self.continues_update = kwargs.get("continues_update", False) + + if self.average_update > 0: + self.average_buffer_pre = torch.zeros( + self.average_update, + *self.feature_value.shape, + device=self.feature_value.device, + ) + self.average_buffer_post = torch.zeros_like(self.average_buffer_pre) + self.average_buffer_index_pre = 0 + self.average_buffer_index_post = 0 + + def _connection_update(self, **kwargs) -> None: + # language=rst + """ + Post-pre learning rule for ``Connection`` subclass of ``AbstractConnection`` + class. + """ + batch_size = self.source.batch_size + + # Pre-synaptic update. + if self.nu[0]: + source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float() + target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0] + + if self.average_update > 0: + self.average_buffer_pre[self.average_buffer_index_pre] = self.reduction( + torch.bmm(source_s, target_x), dim=0 + ) + + self.average_buffer_index_pre = ( + self.average_buffer_index_pre + 1 + ) % self.average_update + + if self.continues_update: + self.feature_value -= ( + torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt + ) + elif self.average_buffer_index_pre == 0: + self.feature_value -= ( + torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt + ) + else: + self.feature_value -= ( + self.reduction(torch.bmm(source_s, target_x), dim=0) + * self.connection.dt + ) + del source_s, target_x + + # Post-synaptic update. + if self.nu[1]: + target_s = ( + self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1] + ) + source_x = self.source.x.view(batch_size, -1).unsqueeze(2) + + if self.average_update > 0: + self.average_buffer_post[self.average_buffer_index_post] = ( + self.reduction(torch.bmm(source_x, target_s), dim=0) + ) + + self.average_buffer_index_post = ( + self.average_buffer_index_post + 1 + ) % self.average_update + + if self.continues_update: + self.feature_value += ( + torch.mean(self.average_buffer_post, dim=0) * self.connection.dt + ) + elif self.average_buffer_index_post == 0: + self.feature_value += ( + torch.mean(self.average_buffer_post, dim=0) * self.connection.dt + ) + else: + self.feature_value += ( + self.reduction(torch.bmm(source_x, target_s), dim=0) + * self.connection.dt + ) + del source_x, target_s + + super().update() + + def reset_state_variables(self): + return + + class Hebbian(MCC_LearningRule): + # language=rst + """ + Simple Hebbian learning rule. Pre- and post-synaptic updates are both positive. + """ + + def __init__( + self, + connection: AbstractMulticompartmentConnection, + feature_value: Union[torch.Tensor, float, int], + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + decay: float = 0.0, + **kwargs, + ) -> None: + # language=rst + """ + Constructor for ``Hebbian`` learning rule. + + :param connection: An ``AbstractConnection`` object whose weights the + ``Hebbian`` learning rule will modify. + :param nu: Single or pair of learning rates for pre- and post-synaptic events. + :param reduction: Method for reducing parameter updates along the batch + dimension. + :param decay: Coefficient controlling rate of decay of the weights each iteration. + """ + super().__init__( + connection=connection, + feature_value=feature_value, + nu=nu, + reduction=reduction, + decay=decay, + **kwargs, + ) + + assert ( + self.source.traces and self.target.traces + ), "Both pre- and post-synaptic nodes must record spike traces." + + if isinstance(MulticompartmentConnection): + self.update = self._connection_update + self.feature_value = feature_value + # elif isinstance(connection, Conv2dConnection): + # self.update = self._conv2d_connection_update + else: + raise NotImplementedError( + "This learning rule is not supported for this Connection type." + ) + + def _connection_update(self, **kwargs) -> None: + # language=rst + """ + Hebbian learning rule for ``Connection`` subclass of ``AbstractConnection`` + class. + """ + + # Add polarities back to feature after updates + if self.enforce_polarity: + self.feature_value = torch.abs(self.feature_value) + + batch_size = self.source.batch_size + + source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float() + source_x = self.source.x.view(batch_size, -1).unsqueeze(2) + target_s = self.target.s.view(batch_size, -1).unsqueeze(1).float() + target_x = self.target.x.view(batch_size, -1).unsqueeze(1) + + # Pre-synaptic update. + update = self.reduction(torch.bmm(source_s, target_x), dim=0) + self.feature_value += self.nu[0] * update + + # Post-synaptic update. + update = self.reduction(torch.bmm(source_x, target_s), dim=0) + self.feature_value += self.nu[1] * update + + # Add polarities back to feature after updates + if self.enforce_polarity: + self.feature_value = self.feature_value * self.polarities + + super().update() + + def reset_state_variables(self): + return + + +class MSTDP(MCC_LearningRule): + # language=rst + """ + Reward-modulated STDP. Adapted from `(Florian 2007) + `_. + """ + + def __init__( + self, + connection: AbstractMulticompartmentConnection, + feature_value: Union[torch.Tensor, float, int], + range: Optional[Sequence[float]] = None, + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + decay: float = 0.0, + enforce_polarity: bool = False, + **kwargs, + ) -> None: + # language=rst + """ + Constructor for ``MSTDP`` learning rule. + + :param connection: An ``AbstractConnection`` object whose weights the ``MSTDP`` + learning rule will modify. + :param feature_value: The object which will be altered + :param range: The domain for the feature + :param nu: Single or pair of learning rates for pre- and post-synaptic events, + respectively. + :param reduction: Method for reducing parameter updates along the minibatch + dimension. + :param decay: Coefficient controlling rate of decay of the weights each iteration. + :param enforce_polarity: Will prevent synapses from changing signs if :code:`True` + + Keyword arguments: + + :param average_update: Number of updates to average over, 0=No averaging, x=average over last x updates + :param continues_update: If True, the update will be applied after every update, if False, only after the average_update buffer is full + + :param tc_plus: Time constant for pre-synaptic firing trace. + :param tc_minus: Time constant for post-synaptic firing trace. + """ + super().__init__( + connection=connection, + feature_value=feature_value, + range=[-1, +1] if range is None else range, + nu=nu, + reduction=reduction, + decay=decay, + enforce_polarity=enforce_polarity, + **kwargs, + ) + + if isinstance(connection, (MulticompartmentConnection)): + self.update = self._connection_update + # elif isinstance(connection, Conv2dConnection): + # self.update = self._conv2d_connection_update + else: + raise NotImplementedError( + "This learning rule is not supported for this Connection type." + ) + + self.tc_plus = torch.tensor(kwargs.get("tc_plus", 20.0)) + self.tc_minus = torch.tensor(kwargs.get("tc_minus", 20.0)) + + # Initialize variables for average update and continues update + self.average_update = kwargs.get("average_update", 0) + self.continues_update = kwargs.get("continues_update", False) + + if self.average_update > 0: + self.average_buffer = torch.zeros( + self.average_update, + *self.feature_value.shape, + device=self.feature_value.device, + ) + self.average_buffer_index = 0 + + def _connection_update(self, **kwargs) -> None: + # language=rst + """ + MSTDP learning rule for ``Connection`` subclass of ``AbstractConnection`` class. + + Keyword arguments: + + :param Union[float, torch.Tensor] reward: Reward signal from reinforcement + learning task. + :param float a_plus: Learning rate (post-synaptic). + :param float a_minus: Learning rate (pre-synaptic). + """ + batch_size = self.source.batch_size + + # Initialize eligibility, P^+, and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros( + # batch_size, *self.source.shape, device=self.source.s.device + batch_size, + self.source.n, + device=self.source.s.device, + ) + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + # batch_size, *self.target.shape, device=self.target.s.device + batch_size, + self.target.n, + device=self.target.s.device, + ) + if not hasattr(self, "eligibility"): + self.eligibility = torch.zeros( + batch_size, *self.feature_value.shape, device=self.feature_value.device + ) + + # Reshape pre- and post-synaptic spikes. + source_s = self.source.s.view(batch_size, -1).float() + target_s = self.target.s.view(batch_size, -1).float() + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = torch.tensor( + kwargs.get("a_plus", 1.0), device=self.feature_value.device + ) + a_minus = torch.tensor( + kwargs.get("a_minus", -1.0), device=self.feature_value.device + ) + + # Compute weight update based on the eligibility value of the past timestep. + update = reward * self.eligibility + + if self.average_update > 0: + self.average_buffer[self.average_buffer_index] = self.reduction( + update, dim=0 + ) + self.average_buffer_index = ( + self.average_buffer_index + 1 + ) % self.average_update + + if self.continues_update: + self.feature_value += self.nu[0] * torch.mean( + self.average_buffer, dim=0 + ) + elif self.average_buffer_index == 0: + self.feature_value += self.nu[0] * torch.mean( + self.average_buffer, dim=0 + ) + else: + self.feature_value += self.nu[0] * self.reduction(update, dim=0) + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) + self.p_plus += a_plus * source_s + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) + self.p_minus += a_minus * target_s + + # Calculate point eligibility value. + self.eligibility = torch.bmm( + self.p_plus.unsqueeze(2), target_s.unsqueeze(1) + ) + torch.bmm(source_s.unsqueeze(2), self.p_minus.unsqueeze(1)) + + super().update() + + def reset_state_variables(self): + return + + +class MSTDPET(MCC_LearningRule): + # language=rst + """ + Reward-modulated STDP with eligibility trace. Adapted from + `(Florian 2007) `_. + """ + + def __init__( + self, + connection: AbstractMulticompartmentConnection, + feature_value: Union[torch.Tensor, float, int], + range: Optional[Sequence[float]] = None, + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + decay: float = 0.0, + enforce_polarity: bool = False, + **kwargs, + ) -> None: + # language=rst + """ + Constructor for ``MSTDPET`` learning rule. + + :param connection: An ``AbstractConnection`` object whose weights the + ``MSTDPET`` learning rule will modify. + :param feature_value: The object which will be altered + :param range: The domain for the feature + :param nu: Single or pair of learning rates for pre- and post-synaptic events, + respectively. + :param reduction: Method for reducing parameter updates along the minibatch + dimension. + :param decay: Coefficient controlling rate of decay of the weights each iteration. + :param enforce_polarity: Will prevent synapses from changing signs if :code:`True` + + Keyword arguments: + + :param float tc_plus: Time constant for pre-synaptic firing trace. + :param float tc_minus: Time constant for post-synaptic firing trace. + :param float tc_e_trace: Time constant for the eligibility trace. + :param average_update: Number of updates to average over, 0=No averaging, x=average over last x updates + :param continues_update: If True, the update will be applied after every update, if False, only after the average_update buffer is full + + """ + super().__init__( + connection=connection, + feature_value=feature_value, + range=[-1, +1] if range is None else range, + nu=nu, + reduction=reduction, + decay=decay, + enforce_polarity=enforce_polarity, + **kwargs, + ) + + if isinstance(connection, (MulticompartmentConnection)): + self.update = self._connection_update + # elif isinstance(connection, Conv2dConnection): + # self.update = self._conv2d_connection_update + else: + raise NotImplementedError( + "This learning rule is not supported for this Connection type." + ) + + self.tc_plus = torch.tensor( + kwargs.get("tc_plus", 20.0) + ) # How long pos reinforcement effects weights + self.tc_minus = torch.tensor( + kwargs.get("tc_minus", 20.0) + ) # How long neg reinforcement effects weights + self.tc_e_trace = torch.tensor( + kwargs.get("tc_e_trace", 25.0) + ) # How long trace effects weights + self.eligibility = torch.zeros( + *self.feature_value.shape, device=self.feature_value.device + ) + self.eligibility_trace = torch.zeros( + *self.feature_value.shape, device=self.feature_value.device + ) + + # Initialize eligibility, eligibility trace, P^+, and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros((self.source.n), device=self.feature_value.device) + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + (self.target.n), device=self.feature_value.device + ) + + # Initialize variables for average update and continues update + self.average_update = kwargs.get("average_update", 0) + self.continues_update = kwargs.get("continues_update", False) + if self.average_update > 0: + self.average_buffer = torch.zeros( + self.average_update, + *self.feature_value.shape, + device=self.feature_value.device, + ) + self.average_buffer_index = 0 + + # @profile + def _connection_update(self, **kwargs) -> None: + # language=rst + """ + MSTDPET learning rule for ``Connection`` subclass of ``AbstractConnection`` + class. + + Keyword arguments: + + :param Union[float, torch.Tensor] reward: Reward signal from reinforcement + learning task. + :param float a_plus: Learning rate (post-synaptic). + :param float a_minus: Learning rate (pre-synaptic). + """ + # Reshape pre- and post-synaptic spikes. + source_s = self.source.s.view(-1).float() + target_s = self.target.s.view(-1).float() + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = kwargs.get("a_plus", 1.0) + # if isinstance(a_plus, dict): + # for k, v in a_plus.items(): + # a_plus[k] = torch.tensor(v, device=self.feature_value.device) + # else: + a_plus = torch.tensor(a_plus, device=self.feature_value.device) + a_minus = kwargs.get("a_minus", -1.0) + # if isinstance(a_minus, dict): + # for k, v in a_minus.items(): + # a_minus[k] = torch.tensor(v, device=self.feature_value.device) + # else: + a_minus = torch.tensor(a_minus, device=self.feature_value.device) + + # Calculate value of eligibility trace based on the value + # of the point eligibility value of the past timestep. + # Note: eligibility = [source.n, target.n] > 0 where source and target spiked + # Note: high negs. -> + self.eligibility_trace *= torch.exp( + -self.connection.dt / self.tc_e_trace + ) # Decay + self.eligibility_trace += self.eligibility / self.tc_e_trace # Additive changes + # ^ Also effected by delay in last step + + # Compute weight update. + + if self.average_update > 0: + self.average_buffer[self.average_buffer_index] = ( + self.nu[0] * self.connection.dt * reward * self.eligibility_trace + ) + self.average_buffer_index = ( + self.average_buffer_index + 1 + ) % self.average_update + + if self.continues_update: + self.feature_value += torch.mean(self.average_buffer, dim=0) + elif self.average_buffer_index == 0: + self.feature_value += torch.mean(self.average_buffer, dim=0) + else: + self.feature_value += ( + self.nu[0] * self.connection.dt * reward * self.eligibility_trace + ) + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) # Decay + self.p_plus += a_plus * source_s # Scaled source spikes + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) # Decay + self.p_minus += a_minus * target_s # Scaled target spikes + + # Notes: + # + # a_plus -> How much a spike in src contributes to the eligibility + # a_minus -> How much a spike in trg contributes to the eligibility (neg) + # p_plus -> +a_plus every spike, with decay + # p_minus -> +a_minus every spike, with decay + + # Calculate point eligibility value. + self.eligibility = torch.outer(self.p_plus, target_s) + torch.outer( + source_s, self.p_minus + ) + + super().update() + + def reset_state_variables(self) -> None: + self.eligibility.zero_() + self.eligibility_trace.zero_() + return diff --git a/bindsnet/network/monitors.py b/bindsnet/network/monitors.py index 7dc7cc3e..f11a2339 100644 --- a/bindsnet/network/monitors.py +++ b/bindsnet/network/monitors.py @@ -4,9 +4,17 @@ import numpy as np import torch +import numpy as np + +from abc import ABC +from typing import Union, Optional, Iterable, Dict from bindsnet.network.nodes import Nodes -from bindsnet.network.topology import AbstractConnection +from bindsnet.network.topology import ( + AbstractConnection, + AbstractMulticompartmentConnection, +) +from bindsnet.network.topology_features import AbstractFeature if TYPE_CHECKING: from .network import Network @@ -27,7 +35,12 @@ class Monitor(AbstractMonitor): def __init__( self, - obj: Union[Nodes, AbstractConnection], + obj: Union[ + Nodes, + AbstractConnection, + AbstractMulticompartmentConnection, + AbstractFeature, + ], state_vars: Iterable[str], time: Optional[int] = None, batch_size: int = 1, @@ -54,8 +67,6 @@ def __init__( if self.time is None: self.device = "cpu" - self.clean = True - self.recording = [] self.reset_state_variables() @@ -179,7 +190,20 @@ def __init__( self.time, *getattr(self.network.connections[c], v).size() ) - def get(self) -> Dict[str, Dict[str, Union[Nodes, AbstractConnection]]]: + def get( + self, + ) -> Dict[ + str, + Dict[ + str, + Union[ + Nodes, + AbstractConnection, + AbstractMulticompartmentConnection, + AbstractFeature, + ], + ], + ]: # language=rst """ Return entire recording to user. diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 98ac2555..cb5fafa1 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -1,8 +1,11 @@ from abc import ABC, abstractmethod from typing import Optional, Sequence, Tuple, Union +import warnings + import numpy as np import torch +from torch import device import torch.nn.functional as F from bindsnet.utils import im2col_indices from torch.nn import Module, Parameter @@ -139,6 +142,112 @@ def reset_state_variables(self) -> None: """ +class AbstractMulticompartmentConnection(ABC, Module): + # language=rst + """ + Abstract base method for connections between ``Nodes``. + """ + + def __init__( + self, + source: Nodes, + target: Nodes, + device: device, + pipeline: list = None, + **kwargs, + ) -> None: + # language=rst + """ + Constructor for abstract base class for connection objects. + + :param source: A layer of nodes from which the connection originates. + :param target: A layer of nodes to which the connection connects. + :param device: The device which the connection will run on + :param pipeline: An ordered list of topology features to be used on the connection + """ + + super().__init__() + + #### General Assertions #### + assert isinstance(source, Nodes), "Source is not a Nodes object" + assert isinstance(target, Nodes), "Target is not a Nodes object" + + #### Assign class variables #### + self.source = source + self.target = target + self.device = device + self.pipeline = ( + [] if pipeline is None else pipeline + ) # <- *Ordered* executables for features + + # TODO: Make it so there can't be repeated names!!! + # Initialize feature index & prime + self.feature_index = ( + {} + ) # <- *Unordered* and named set of references for features + for feature in pipeline: + self.feature_index[feature.name] = feature + feature.prime_feature(connection=self, device=self.device, **kwargs) + + @abstractmethod + def compute(self, s: torch.Tensor) -> None: + # language=rst + """ + Compute pre-activations of downstream neurons given spikes of upstream neurons. + + :param s: Incoming spikes. + """ + pass + + @abstractmethod + def update(self, **kwargs) -> None: + # language=rst + """ + Compute connection's update rule. + + Keyword arguments: + + :param bool learning: Whether to allow connection updates. + """ + pass + + @abstractmethod + def reset_state_variables(self) -> None: + # language=rst + """ + Contains resetting logic for the connection. + """ + pass + + def append_pipeline(self, feature) -> None: + # language=rst + """ + Append a feature to the pipeline + """ + self.pipeline.append(feature) + feature.prime_feature(connection=self, device=self.device) + self.feature_index[feature.name] = feature + + def insert_pipeline(self, feature, index) -> None: + # language=rst + """ + insert a feature into the pipeline + :param index: Index for where to insert the feature + """ + self.pipeline.insert(feature, index) + feature.prime_feature(connection=self, device=self.device) + self.feature_index[feature.name] = feature + + def remove_pipeline(self, feature) -> None: + # language=rst + """ + remove a feature frome the pipeline + :param feature: feature to be removed + """ + self.pipeline.remove(feature) + del self.feature_index[feature.name] + + class Connection(AbstractConnection): # language=rst """ @@ -272,6 +381,133 @@ def reset_state_variables(self) -> None: super().reset_state_variables() +class MulticompartmentConnection(AbstractMulticompartmentConnection): + # language=rst + """ + Specifies synapses between one or two populations of neurons. + """ + + def __init__( + self, + source: Nodes, + target: Nodes, + device: device, + pipeline: list = [], + manual_update: bool = False, + traces: bool = False, + **kwargs, + ) -> None: + # language=rst + """ + Instantiates a :code:`Connection` object. + + :param source: A layer of nodes from which the connection originates. + :param target: A layer of nodes to which the connection connects. + :param device: The device the connection will be run on. + :param list: Pipeline of features for the connection signals to be run through + :param manual_update: Set to :code:`True` to disable automatic updates (applying learning rules) to connection features. + False by default, updates called after each time step + :param traces: Set to :code:`True` to record history of connection activity (for monitors) + """ + + super().__init__(source, target, device, pipeline, **kwargs) + self.traces = traces + self.manual_update = manual_update + if self.traces: + self.activity = None + + def compute(self, s: torch.Tensor) -> torch.Tensor: + # language=rst + """ + Compute pre-activations given spikes using connection weights. + + :param s: Incoming spikes. + :return: Incoming spikes multiplied by synaptic weights (with or without + decaying spike activation). + """ + + # Change to numeric type (torch doesn't like booleans for matrix ops) + # Note: .float() is an expensive operation. Use as minimally as possible! + # if s.dtype != torch.float32: + # s = s.float() + + # Prepare broadcast from incoming spikes to all output neurons + # |conn_spikes| = [batch_size, source.n * target.n] + conn_spikes = s.view(s.size(0), self.source.n, 1).repeat(1, 1, self.target.n) + # TODO: ^ This could probably be optimized + + # Run through pipeline + for f in self.pipeline: + conn_spikes = f.compute(conn_spikes) + + # Sum signals for each of the output/terminal neurons + # |out_signal| = [batch_size, target.n] + out_signal = conn_spikes.view(s.size(0), self.source.n, self.target.n).sum(1) + + if self.traces: + self.activity = out_signal + + return out_signal.view(s.size(0), *self.target.shape) + + def compute_window(self, s: torch.Tensor) -> torch.Tensor: + # language=rst + """""" + + if self.s_w == None: + # Construct a matrix of shape batch size * window size * dimension of layer + self.s_w = torch.zeros( + self.target.batch_size, self.target.res_window_size, *self.source.shape + ) + + # Add the spike vector into the first in first out matrix of windowed (res) spike trains + self.s_w = torch.cat((self.s_w[:, 1:, :], s[:, None, :]), 1) + + # Compute multiplication of spike activations by weights and add bias. + if self.b is None: + post = ( + self.s_w.view(self.s_w.size(0), self.s_w.size(1), -1).float() @ self.w + ) + else: + post = ( + self.s_w.view(self.s_w.size(0), self.s_w.size(1), -1).float() @ self.w + + self.b + ) + + return post.view( + self.s_w.size(0), self.target.res_window_size, *self.target.shape + ) + + def update(self, **kwargs) -> None: + # language=rst + """ + Compute connection's update rule. + """ + learning = kwargs.get("learning", False) + if learning and not self.manual_update: + # Pipeline learning + for f in self.pipeline: + f.update(**kwargs) + + def normalize(self) -> None: + # language=rst + """ + Normalize all features in the connection. + """ + # Normalize pipeline features + for f in self.pipeline: + f.normalize() + + def reset_state_variables(self) -> None: + # language=rst + """ + Contains resetting logic for the connection. + """ + super().reset_state_variables() + + for f in self.pipeline: + f.reset_state_variables() + + class Conv1dConnection(AbstractConnection): # language=rst """ diff --git a/bindsnet/network/topology_features.py b/bindsnet/network/topology_features.py new file mode 100644 index 00000000..f99cf39f --- /dev/null +++ b/bindsnet/network/topology_features.py @@ -0,0 +1,940 @@ +from abc import ABC, abstractmethod +from bindsnet.learning.learning import NoOp +from typing import Union, Tuple, Optional, Sequence + +import numpy as np +import torch +from torch import device +from torch.nn import Parameter +import torch.nn.functional as F +import torch.nn as nn +import bindsnet.learning + + +class AbstractFeature(ABC): + # language=rst + """ + Features to operate on signals traversing a connection. + """ + + @abstractmethod + def __init__( + self, + name: str, + value: Union[torch.Tensor, float, int] = None, + range: Optional[Union[list, tuple]] = None, + clamp_frequency: Optional[int] = 1, + norm: Optional[Union[torch.Tensor, float, int]] = None, + learning_rule: Optional[bindsnet.learning.LearningRule] = None, + nu: Optional[Union[list, tuple, int, float]] = None, + reduction: Optional[callable] = None, + enforce_polarity: Optional[bool] = False, + decay: float = 0.0, + parent_feature=None, + **kwargs, + ) -> None: + # language=rst + """ + Instantiates a :code:`Feature` object. Will assign all incoming arguments as class variables + :param name: Name of the feature + :param value: Core numeric object for the feature. This parameters function will vary depending on the feature + :param range: Range of acceptable values for the :code:`value` parameter + :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each + sample and after the value has been updated by the learning rule (if there is one) + :param learning_rule: Rule which will modify the :code:`value` after each sample + :param nu: Learning rate for the learning rule + :param reduction: Method for reducing parameter updates along the minibatch + dimension + :param decay: Constant multiple to decay weights by on each iteration + :param parent_feature: Parent feature to inherit :code:`value` from + """ + + #### Initialize class variables #### + ## Args ## + self.name = name + self.value = value + self.range = [-1.0, 1.0] if range is None else range + self.clamp_frequency = clamp_frequency + self.norm = norm + self.learning_rule = learning_rule + self.nu = nu + self.reduction = reduction + self.decay = decay + self.parent_feature = parent_feature + self.kwargs = kwargs + + ## Backend ## + self.is_primed = False + + from ..learning.MCC_learning import ( + NoOp, + PostPre, + MSTDP, + MSTDPET, + ) + + supported_rules = [ + NoOp, + PostPre, + MSTDP, + MSTDPET, + ] + + #### Assertions #### + # Assert correct instance of feature values + assert isinstance(name, str), "Feature {0}'s name should be of type str".format( + name + ) + assert value is None or isinstance( + value, (torch.Tensor, float, int) + ), "Feature {0} should be of type float, int, or torch.Tensor, not {1}".format( + name, type(value) + ) + assert norm is None or isinstance( + norm, (torch.Tensor, float, int) + ), "Feature {0}'s norm should be of type float, int, or torch.Tensor, not {1}".format( + name, type(norm) + ) + assert learning_rule is None or ( + learning_rule in supported_rules + ), "Feature {0}'s learning_rule should be of type bindsnet.LearningRule not {1}".format( + name, type(learning_rule) + ) + assert nu is None or isinstance( + nu, (list, tuple) + ), "Feature {0}'s nu should be of type list or tuple, not {1}".format( + name, type(nu) + ) + assert reduction is None or isinstance( + reduction, callable + ), "Feature {0}'s reduction should be of type callable, not {1}".format( + name, type(reduction) + ) + assert decay is None or isinstance( + decay, float + ), "Feature {0}'s decay should be of type float, not {1}".format( + name, type(decay) + ) + + self.assert_valid_range() + if value is not None: + self.assert_feature_in_range() + + @abstractmethod + def reset_state_variables(self) -> None: + # language=rst + """ + Contains resetting logic for the feature. + """ + if self.learning_rule: + self.learning_rule.reset_state_variables() + pass + + @abstractmethod + def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: + # language=rst + """ + Computes the feature being operated on a set of incoming signals. + """ + pass + + def prime_feature(self, connection, device, **kwargs) -> None: + # language=rst + """ + Prepares a feature after it has been placed in a connection. This takes care of learning rules, feature + value initialization, and asserting that features have proper shape. Should occur after primary constructor. + """ + + # Note: DO NOT move NoOp to global; cyclical dependency + from ..learning.MCC_learning import NoOp + + # Check if feature is already primed + if self.is_primed: + return + self.is_primed = True + + # Check if feature is a child feature + if self.parent_feature is not None: + self.link(self.parent_feature) + self.learning_rule = NoOp(connection=connection) + return + + # Check if values/norms are the correct shape + if isinstance(self.value, torch.Tensor): + assert tuple(self.value.shape) == (connection.source.n, connection.target.n) + + if self.norm is not None and isinstance(self.norm, torch.Tensor): + assert self.norm.shape[0] == connection.target.n + + #### Initialize feature value #### + if self.value is None: + self.value = ( + self.initialize_value() + ) # This should be defined per the sub-class + + if isinstance(self.value, (int, float)): + self.value = torch.Tensor([self.value]) + + # Parameterize and send to proper device + # Note: Floating is used here to avoid dtype conflicts + self.value = Parameter(self.value, requires_grad=False).to(device) + + ##### Initialize learning rule ##### + + # Default is NoOp + if self.learning_rule is None: + self.learning_rule = NoOp + + self.learning_rule = self.learning_rule( + connection=connection, + feature_value=self.value, + range=self.range, + nu=self.nu, + reduction=self.reduction, + decay=self.decay, + **kwargs, + ) + + #### Recycle unnecessary variables #### + del self.nu, self.reduction, self.decay, self.range + + def update(self, **kwargs) -> None: + # language=rst + """ + Compute feature's update rule + """ + + self.learning_rule.update(**kwargs) + + def normalize(self) -> None: + # language=rst + """ + Normalize feature so each target neuron has sum of feature values equal to + ``self.norm``. + """ + + if self.norm is not None: + abs_sum = self.value.sum(0).unsqueeze(0) + abs_sum[abs_sum == 0] = 1.0 + self.value *= self.norm / abs_sum + + def degrade(self) -> None: + # language=rst + """ + Degrade the value of the propagated spikes according to the features value. A lambda function should be passed + into the constructor which takes a single argument (which represent the value), and returns a value which will + be *subtracted* from the propagated spikes. + """ + + return self.degrade(self.value) + + def link(self, parent_feature) -> None: + # language=rst + """ + Allow two features to share tensor values + """ + + valid_features = (Probability, Weight, Bias, Intensity) + + assert isinstance(self, valid_features), f"A {self} cannot use feature linking" + assert isinstance( + parent_feature, valid_features + ), f"A {parent_feature} cannot use feature linking" + assert self.is_primed, f"Prime feature before linking: {self}" + assert ( + parent_feature.is_primed + ), f"Prime parent feature before linking: {parent_feature}" + + # Link values, disable learning for this feature + self.value = parent_feature.value + self.learning_rule = NoOp + + def assert_valid_range(self): + # language=rst + """ + Default range verifier (within [-1, +1]) + """ + + r = self.range + + ## Check dtype ## + assert isinstance( + self.range, (list, tuple) + ), f"Invalid range for feature {self.name}: range should be a list or tuple, not {type(self.range)}" + assert ( + len(r) == 2 + ), f"Invalid range for feature {self.name}: range should have a length of 2" + + ## Check min/max relation ## + if isinstance(r[0], torch.Tensor) or isinstance(r[1], torch.Tensor): + assert ( + r[0] < r[1] + ).all(), f"Invalid range for feature {self.name}: a min is larger than an adjacent max" + else: + assert ( + r[0] < r[1] + ), f"Invalid range for feature {self.name}: the min value is larger than the max value" + + def assert_feature_in_range(self): + r = self.range + f = self.value + + if isinstance(r[0], torch.Tensor) or isinstance(f, torch.Tensor): + assert ( + f >= r[0] + ).all(), f"Feature out of range for {self.name}: Features values not in [{r[0]}, {r[1]}]" + else: + assert ( + f >= r[0] + ), f"Feature out of range for {self.name}: Features values not in [{r[0]}, {r[1]}]" + + if isinstance(r[1], torch.Tensor) or isinstance(f, torch.Tensor): + assert ( + f <= r[1] + ).all(), f"Feature out of range for {self.name}: Features values not in [{r[0]}, {r[1]}]" + else: + assert ( + f <= r[1] + ), f"Feature out of range for {self.name}: Features values not in [{r[0]}, {r[1]}]" + + def assert_valid_shape(self, source_shape, target_shape, f): + # Multidimensional feat + if len(f.shape) > 1: + assert f.shape == ( + source_shape, + target_shape, + ), f"Feature {self.name} has an incorrect shape of {f.shape}. Should be of shape {(source_shape, target_shape)}" + # Else assume scalar, which is a valid shape + + +class Probability(AbstractFeature): + def __init__( + self, + name: str, + value: Union[torch.Tensor, float, int] = None, + range: Optional[Sequence[float]] = None, + norm: Optional[Union[torch.Tensor, float, int]] = None, + learning_rule: Optional[bindsnet.learning.LearningRule] = None, + nu: Optional[Union[list, tuple]] = None, + reduction: Optional[callable] = None, + decay: float = 0.0, + parent_feature=None, + ) -> None: + # language=rst + """ + Will run a bernoulli trial using :code:`value` to determine if a signal will successfully traverse the synapse + :param name: Name of the feature + :param value: Number(s) in [0, 1] which represent the probability of a signal traversing a synapse. Tensor values + assume that probabilities will be matched to adjacent synapses in the connection. Scalars will be applied to + all synapses. + :param range: Range of acceptable values for the :code:`value` parameter. Should be in [0, 1] + :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample + and after the value has been updated by the learning rule (if there is one) + :param learning_rule: Rule which will modify the :code:`value` after each sample + :param nu: Learning rate for the learning rule + :param reduction: Method for reducing parameter updates along the minibatch + dimension + :param decay: Constant multiple to decay weights by on each iteration + :param parent_feature: Parent feature to inherit :code:`value` from + """ + + ### Assertions ### + super().__init__( + name=name, + value=value, + range=[0, 1] if range is None else range, + norm=norm, + learning_rule=learning_rule, + nu=nu, + reduction=reduction, + decay=decay, + parent_feature=parent_feature, + ) + + def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: + return conn_spikes * torch.bernoulli(self.value) + + def reset_state_variables(self) -> None: + pass + + def prime_feature(self, connection, device, **kwargs) -> None: + ## Initialize value ### + if self.value is None: + self.initialize_value = lambda: torch.clamp( + torch.rand(connection.source.n, connection.target.n, device=device), + self.range[0], + self.range[1], + ) + + super().prime_feature(connection, device, **kwargs) + + def assert_valid_range(self): + super().assert_valid_range() + + r = self.range + + ## Check min greater than 0 ## + if isinstance(r[0], torch.Tensor): + assert ( + r[0] >= 0 + ).all(), ( + f"Invalid range for feature {self.name}: a min value is less than 0" + ) + elif isinstance(r[0], (float, int)): + assert ( + r[0] >= 0 + ), f"Invalid range for feature {self.name}: the min value is less than 0" + else: + assert ( + False + ), f"Invalid range for feature {self.name}: the min value must be of type torch.Tensor, float, or int" + + +class Mask(AbstractFeature): + def __init__( + self, + name: str, + value: Union[torch.Tensor, float, int] = None, + ) -> None: + # language=rst + """ + Boolean mask which determines whether or not signals are allowed to traverse certain synapses. + :param name: Name of the feature + :param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable + """ + + ### Assertions ### + if isinstance(value, torch.Tensor): + assert ( + value.dtype == torch.bool + ), "Mask must be of type bool, not {}".format(value.dtype) + elif value is not None: + assert isinstance(value, bool), "Mask must be of type bool, not {}".format( + value.dtype + ) + + # Send boolean to tensor (priming wont work if it's not a tensor) + value = torch.tensor(value) + + super().__init__( + name=name, + value=value, + ) + + self.name = name + self.value = value + + def compute(self, conn_spikes) -> torch.Tensor: + return conn_spikes * self.value + + def reset_state_variables(self) -> None: + pass + + def prime_feature(self, connection, device, **kwargs) -> None: + # Check if feature is already primed + if self.is_primed: + return + self.is_primed = True + + #### Initialize feature value #### + if self.value is None: + self.value = ( + torch.rand(connection.source.n, connection.target.n) > 0.99 + ).to(device=device) + self.value = Parameter(self.value, requires_grad=False).to(device) + + #### Assertions #### + # Check if tensor values are the correct shape + if isinstance(self.value, torch.Tensor): + self.assert_valid_shape( + connection.source.n, connection.target.n, self.value + ) + + ##### Initialize learning rule ##### + # Note: DO NOT move NoOp to global; cyclical dependency + from ..learning.MCC_learning import NoOp + + # Default is NoOp + if self.learning_rule is None: + self.learning_rule = NoOp + + self.learning_rule = self.learning_rule( + connection=connection, + feature=self.value, + range=self.range, + nu=self.nu, + reduction=self.reduction, + decay=self.decay, + **kwargs, + ) + + +class MeanField(AbstractFeature): + def __init__(self) -> None: + # language=rst + """ + Takes the mean of all outgoing signals, and outputs that mean across every synapse in the connection + """ + pass + + def reset_state_variables(self) -> None: + pass + + def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: + return conn_spikes.mean() * torch.ones( + self.source_n * self.target_n, device=self.device + ) + + def prime_feature(self, connection, device, **kwargs) -> None: + self.source_n = connection.source.n + self.target_n = connection.target.n + + super().prime_feature(connection, device, **kwargs) + + +class Weight(AbstractFeature): + def __init__( + self, + name: str, + value: Union[torch.Tensor, float, int] = None, + range: Optional[Sequence[float]] = None, + norm: Optional[Union[torch.Tensor, float, int]] = None, + norm_frequency: Optional[str] = "sample", + learning_rule: Optional[bindsnet.learning.LearningRule] = None, + nu: Optional[Union[list, tuple]] = None, + reduction: Optional[callable] = None, + enforce_polarity: Optional[bool] = False, + decay: float = 0.0, + ) -> None: + # language=rst + """ + Multiplies signals by scalars + :param name: Name of the feature + :param value: Values to scale signals by + :param range: Range of acceptable values for the :code:`value` parameter + :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample + and after the value has been updated by the learning rule (if there is one) + :param norm_frequency: How often to normalize weights: + * 'sample': weights normalized after each sample + * 'time step': weights normalized after each time step + :param learning_rule: Rule which will modify the :code:`value` after each sample + :param nu: Learning rate for the learning rule + :param reduction: Method for reducing parameter updates along the minibatch + dimension + :param enforce_polarity: Will prevent synapses from changing signs if :code:`True` + :param decay: Constant multiple to decay weights by on each iteration + """ + + self.norm_frequency = norm_frequency + self.enforce_polarity = enforce_polarity + super().__init__( + name=name, + value=value, + range=[-torch.inf, +torch.inf] if range is None else range, + norm=norm, + learning_rule=learning_rule, + nu=nu, + reduction=reduction, + decay=decay, + ) + + def reset_state_variables(self) -> None: + pass + + def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: + if self.enforce_polarity: + pos_mask = ~torch.logical_xor(self.value > 0, self.positive_mask) + neg_mask = ~torch.logical_xor(self.value < 0, ~self.positive_mask) + self.value = self.value * torch.logical_or(pos_mask, neg_mask) + self.value[~pos_mask] = 0.0001 + self.value[~neg_mask] = -0.0001 + + return_val = self.value * conn_spikes + if self.norm_frequency == "time step": + self.normalize(time_step_norm=True) + + return return_val + + def prime_feature(self, connection, device, **kwargs) -> None: + #### Initialize value #### + if self.value is None: + self.initialize_value = lambda: torch.rand( + connection.source.n, connection.target.n + ) + + super().prime_feature( + connection, device, enforce_polarity=self.enforce_polarity, **kwargs + ) + if self.enforce_polarity: + self.positive_mask = ((self.value > 0).sum(1) / self.value.shape[1]) > 0.5 + tmp = torch.zeros_like(self.value) + tmp[self.positive_mask, :] = 1 + self.positive_mask = tmp.bool() + + def normalize(self, time_step_norm=False) -> None: + # 'time_step_norm' will indicate if normalize is being called from compute() + # or from network.py (after a sample is completed) + + if self.norm_frequency == "time step" and time_step_norm: + super().normalize() + + if self.norm_frequency == "sample" and not time_step_norm: + super().normalize() + + +class Bias(AbstractFeature): + def __init__( + self, + name: str, + value: Union[torch.Tensor, float, int] = None, + range: Optional[Sequence[float]] = None, + norm: Optional[Union[torch.Tensor, float, int]] = None, + ) -> None: + # language=rst + """ + Adds scalars to signals + :param name: Name of the feature + :param value: Values to add to the signals + :param range: Range of acceptable values for the :code:`value` parameter + :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample + and after the value has been updated by the learning rule (if there is one) + """ + + super().__init__( + name=name, + value=value, + range=[-torch.inf, +torch.inf] if range is None else range, + norm=norm, + ) + + def reset_state_variables(self) -> None: + pass + + def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: + return conn_spikes + self.value + + def prime_feature(self, connection, device, **kwargs) -> None: + #### Initialize value #### + if self.value is None: + self.initialize_value = lambda: torch.rand( + connection.source.n, connection.target.n + ) + + super().prime_feature(connection, device, **kwargs) + + +class Intensity(AbstractFeature): + def __init__( + self, + name: str, + value: Union[torch.Tensor, float, int] = None, + range: Optional[Sequence[float]] = None, + ) -> None: + # language=rst + """ + Adds scalars to signals + :param name: Name of the feature + :param value: Values to scale signals by + """ + + super().__init__(name=name, value=value, range=range) + + def reset_state_variables(self) -> None: + pass + + def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: + return conn_spikes * self.value + + def prime_feature(self, connection, device, **kwargs) -> None: + #### Initialize value #### + if self.value is None: + self.initialize_value = lambda: torch.clamp( + torch.sign( + torch.randint(-1, +2, (connection.source.n, connection.target.n)) + ), + self.range[0], + self.range[1], + ) + + super().prime_feature(connection, device, **kwargs) + + +class Degradation(AbstractFeature): + def __init__( + self, + name: str, + value: Union[torch.Tensor, float, int] = None, + degrade_function: callable = None, + parent_feature: Optional[AbstractFeature] = None, + ) -> None: + # language=rst + """ + Degrades propagating spikes according to :code:`degrade_function`. + Note: If :code:`parent_feature` is provided, it will override :code:`value`. + :param name: Name of the feature + :param value: Value used to degrade feature + :param degrade_function: Callable function which takes a single argument (:code:`value`) and returns a tensor or + constant to be *subtracted* from the propagating spikes. + :param parent_feature: Parent feature with desired :code:`value` to inherit + """ + + # Note: parent_feature will override value. See abstract constructor + super().__init__(name=name, value=value, parent_feature=parent_feature) + + self.degrade_function = degrade_function + + def reset_state_variables(self) -> None: + pass + + def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: + return conn_spikes - self.degrade_function(self.value) + + +class AdaptationBaseSynapsHistory(AbstractFeature): + def __init__( + self, + name: str, + value: Union[torch.Tensor, float, int] = None, + ann_values: Union[list, tuple] = None, + const_update_rate: float = 0.1, + const_decay: float = 0.001, + ) -> None: + # language=rst + """ + The ANN will be use on each synaps to messure the previous activity of the neuron and descide to close or open connection. + + :param name: Name of the feature + :param ann_values: Values to be use to build an ANN that will adapt the connectivity of the layer. + :param value: Values to be use to build an initial mask for the synapses. + :param const_update_rate: The mask upatate rate of the ANN decision. + :param const_decay: The spontaneous activation of the synapses. + """ + + # Define the ANN + class ANN(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(ANN, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size, bias=False) + self.fc2 = nn.Linear(hidden_size, output_size, bias=False) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = torch.tanh(self.fc2(x)) # MUST HAVE output between -1 and 1 + return x + + self.init_value = value.clone().detach() # initial mask + self.mask = value # final decision of the ANN + value = torch.zeros_like(value) # initial mask + self.ann = ANN(ann_values[0].shape[0], ann_values[0].shape[1], 1) + + # load weights from ann_values + with torch.no_grad(): + self.ann.fc1.weight.data = ann_values[0] + self.ann.fc2.weight.data = ann_values[1] + self.ann.to(ann_values[0].device) + + self.spike_buffer = torch.zeros( + (value.numel(), ann_values[0].shape[1]), + device=ann_values[0].device, + dtype=torch.bool, + ) + self.counter = 0 + self.start_counter = False + self.const_update_rate = const_update_rate + self.const_decay = const_decay + + super().__init__(name=name, value=value) + + def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: + + # Update the spike buffer + if self.start_counter == False or conn_spikes.sum() > 0: + self.start_counter = True + self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = ( + conn_spikes.flatten() + ) + self.counter += 1 + + # Update the masks + if self.counter % self.spike_buffer.shape[1] == 0: + with torch.no_grad(): + ann_decision = self.ann(self.spike_buffer.to(torch.float32)) + self.mask += ( + ann_decision.view(self.mask.shape) * self.const_update_rate + ) # update mask with learning rate fraction + self.mask += self.const_decay # spontaneous activate synapses + self.mask = torch.clamp(self.mask, -1, 1) # cap the mask + + # self.mask = torch.clamp(self.mask, -1, 1) + self.value = (self.mask > 0).float() + + return conn_spikes * self.value + + def reset_state_variables( + self, + ): + self.spike_buffer = torch.zeros_like(self.spike_buffer) + self.counter = 0 + self.start_counter = False + self.value = self.init_value.clone().detach() # initial mask + pass + + +class AdaptationBaseOtherSynaps(AbstractFeature): + def __init__( + self, + name: str, + value: Union[torch.Tensor, float, int] = None, + ann_values: Union[list, tuple] = None, + const_update_rate: float = 0.1, + const_decay: float = 0.01, + ) -> None: + # language=rst + """ + The ANN will be use on each synaps to messure the previous activity of the neuron and descide to close or open connection. + + :param name: Name of the feature + :param ann_values: Values to be use to build an ANN that will adapt the connectivity of the layer. + :param value: Values to be use to build an initial mask for the synapses. + :param const_update_rate: The mask upatate rate of the ANN decision. + :param const_decay: The spontaneous activation of the synapses. + """ + + # Define the ANN + class ANN(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(ANN, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size, bias=False) + self.fc2 = nn.Linear(hidden_size, output_size, bias=False) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = torch.tanh(self.fc2(x)) # MUST HAVE output between -1 and 1 + return x + + self.init_value = value.clone().detach() # initial mask + self.mask = value # final decision of the ANN + value = torch.zeros_like(value) # initial mask + self.ann = ANN(ann_values[0].shape[0], ann_values[0].shape[1], 1) + + # load weights from ann_values + with torch.no_grad(): + self.ann.fc1.weight.data = ann_values[0] + self.ann.fc2.weight.data = ann_values[1] + self.ann.to(ann_values[0].device) + + self.spike_buffer = torch.zeros( + (value.numel(), ann_values[0].shape[1]), + device=ann_values[0].device, + dtype=torch.bool, + ) + self.counter = 0 + self.start_counter = False + self.const_update_rate = const_update_rate + self.const_decay = const_decay + + super().__init__(name=name, value=value) + + def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: + + # Update the spike buffer + if self.start_counter == False or conn_spikes.sum() > 0: + self.start_counter = True + self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = ( + conn_spikes.flatten() + ) + self.counter += 1 + + # Update the masks + if self.counter % self.spike_buffer.shape[1] == 0: + with torch.no_grad(): + ann_decision = self.ann(self.spike_buffer.to(torch.float32)) + self.mask += ( + ann_decision.view(self.mask.shape) * self.const_update_rate + ) # update mask with learning rate fraction + self.mask += self.const_decay # spontaneous activate synapses + self.mask = torch.clamp(self.mask, -1, 1) # cap the mask + + # self.mask = torch.clamp(self.mask, -1, 1) + self.value = (self.mask > 0).float() + + return conn_spikes * self.value + + def reset_state_variables( + self, + ): + self.spike_buffer = torch.zeros_like(self.spike_buffer) + self.counter = 0 + self.start_counter = False + self.value = self.init_value.clone().detach() # initial mask + pass + + +### Sub Features ### + + +class AbstractSubFeature(ABC): + # language=rst + """ + A way to inject a features methods (like normalization, learning, etc.) into the pipeline for user controlled + execution. + """ + + @abstractmethod + def __init__( + self, + name: str, + parent_feature: AbstractFeature, + ) -> None: + # language=rst + """ + Instantiates a :code:`Augment` object. Will assign all incoming arguments as class variables. + :param name: Name of the augment + :param parent_feature: Primary feature which the augment will modify + """ + + self.name = name + self.parent = parent_feature + self.sub_feature = None # <-- Defined in non-abstract constructor + + def compute(self, _) -> None: + # language=rst + """ + Proxy function to catch a pipeline execution from topology.py's :code:`compute` function. Allows :code:`SubFeature` + objects to be executed like real features in the pipeline. + """ + + # sub_feature should be defined in the non-abstract constructor + self.sub_feature() + + +class Normalization(AbstractSubFeature): + # language=rst + """ + Normalize parent features values so each target neuron has sum of feature values equal to a desired value :code:`norm`. + """ + + def __init__( + self, + name: str, + parent_feature: AbstractFeature, + ) -> None: + super().__init__(name, parent_feature) + + self.sub_feature = self.parent.normalize + + +class Updating(AbstractSubFeature): + # language=rst + """ + Update parent features values using the assigned update rule. + """ + + def __init__( + self, + name: str, + parent_feature: AbstractFeature, + ) -> None: + super().__init__(name, parent_feature) + + self.sub_feature = self.parent.update diff --git a/docs/source/guide/guide_part_i.rst b/docs/source/guide/guide_part_i.rst index ff425e8b..d8bc35d7 100644 --- a/docs/source/guide/guide_part_i.rst +++ b/docs/source/guide/guide_part_i.rst @@ -145,6 +145,33 @@ weights based on pre-, post-synaptic activity and possibly other signals; e.g., :code:`normalize` (for ensuring weights incident to post-synaptic neurons sum to a pre-specified value), and :code:`reset_state_variables` (for re-initializing stateful variables for the start of a new simulation). +For more complex connections, the MulticompartmentConnection class can be used. The MulticompartmentConnection will pass spikes through different "features" +such as weights, bias's, and boolean masks in a specified order. Features are passed to the MulticompartmentConnection constructor in a list, and executed in order. +For example, the code below uses a pipeline containing a weight and bias feature. During runtime, spikes from the source will be multiplied by the weights first, +then a bias added second. Additional features can be added before/after/between these two. +To create a simple all-to-all connection with a weight and bias: + +.. code-block:: python + + from bindsnet.network.nodes import Input, LIFNodes + from bindsnet.network.topology import MulticompartmentConnection + from bindsnet.network.topology_features import Weight, Bias + + # Create two populations of neurons, one to act as the "source" + # population, and the other, the "target population". + source_layer = Input(n=100) + target_layer = LIFNodes(n=1000) + + # Create 'pipeline' of features that spikes will pass through + weights = Weight(name='weight_feature', value=torch.rand(100, 1000)) + bias = Bias(name='bias_feature', value=torch.rand(100, 1000)) + + # Connect the two layers. + connection = MulticompartmentConnection( + source=source_layer, target=target_layer, + pipeline=[weight, bias] + ) + Specifying monitors ******************* @@ -176,7 +203,9 @@ course of simulation in certain network components. To create a monitor to monit The user must specify a :code:`Nodes` or :code:`AbstractConnection` object from which to record, attributes of that object to record (:code:`state_vars`), and, optionally, how many time steps the simulation(s) will last, in order to -save time by pre-allocating memory. +save time by pre-allocating memory. + +Monitors are not officially supported for MulticompartmentConnection To add a monitor to the network (thereby enabling monitoring), use the :code:`add_monitor` function of the :py:class:`bindsnet.network.Network` class: diff --git a/docs/source/guide/guide_part_ii.rst b/docs/source/guide/guide_part_ii.rst index 13ef5b31..786753a3 100644 --- a/docs/source/guide/guide_part_ii.rst +++ b/docs/source/guide/guide_part_ii.rst @@ -74,4 +74,33 @@ Custom learning rules can be implemented by subclassing :code:`bindsnet.learning and providing implementations for the types of :code:`AbstractConnection` objects intended to be used. For example, the :code:`Connection` and :code:`LocalConnection` objects rely on the implementation of a private method, :code:`_connection_update`, whereas the :code:`Conv2dConnection` object -uses the :code:`_conv2d_connection_update` version. \ No newline at end of file +uses the :code:`_conv2d_connection_update` version. + +If using a MulticompartmentConneciton, you can add a learning rule to a specific feature. Note that only +:code:`NoOp`, :code:`PostPre`, :code:`MSTDP`, :code:`MSTDPET` are supported, and located at +bindsnet.learning.MCC_learning. Below is an example of how to apply a PostPre learning rule to a weight function. +Note that the bias does not have a learning rule, so it will remain static. + +.. code-block:: python + + from bindsnet.network.nodes import Input, LIFNodes + from bindsnet.network.topology import MulticompartmentConnection + from bindsnet.learning.MCC_learning import PostPre + + # Create two populations of neurons, one to act as the "source" + # population, and the other, the "target population". + # Neurons involved in certain learning rules must record synaptic + # traces, a vector of short-term memories of the last emitted spikes. + source_layer = Input(n=100, traces=True) + target_layer = LIFNodes(n=1000, traces=True) + + # Create 'pipeline' of features that spikes will pass through + weights = Weight(name='weight_feature', value=torch.rand(100, 1000), + learning_rule=PostPre, nu=(1e-4, 1e-2)) + bias = Bias(name='bias_feature', value=torch.rand(100, 1000)) + + # Connect the two layers. + connection = MulticompartmentConnection( + source=source_layer, target=target_layer, + pipeline=[weights, bias]) + ) \ No newline at end of file diff --git a/examples/mnist/MCC_reservoir.py b/examples/mnist/MCC_reservoir.py new file mode 100644 index 00000000..89145d6f --- /dev/null +++ b/examples/mnist/MCC_reservoir.py @@ -0,0 +1,309 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import argparse +import matplotlib.pyplot as plt + +from torchvision import transforms +from tqdm import tqdm + +from bindsnet.analysis.plotting import ( + plot_input, + plot_spikes, + plot_voltages, + plot_weights, +) +from bindsnet.datasets import MNIST +from bindsnet.encoding import PoissonEncoder +from bindsnet.network import Network +from bindsnet.network.nodes import Input +from bindsnet.network.topology_features import Probability, Weight, Mask + +# Build a simple two-layer, input-output network. +from bindsnet.network.monitors import Monitor +from bindsnet.network.nodes import LIFNodes +from bindsnet.network.topology import MulticompartmentConnection +from bindsnet.utils import get_square_weights + + +parser = argparse.ArgumentParser() +parser.add_argument("--seed", type=int, default=0) +parser.add_argument("--n_neurons", type=int, default=500) +parser.add_argument("--n_epochs", type=int, default=100) +parser.add_argument("--examples", type=int, default=500) +parser.add_argument("--n_workers", type=int, default=-1) +parser.add_argument("--time", type=int, default=250) +parser.add_argument("--dt", type=int, default=1.0) +parser.add_argument("--intensity", type=float, default=64) +parser.add_argument("--progress_interval", type=int, default=10) +parser.add_argument("--update_interval", type=int, default=250) +parser.add_argument("--plot", dest="plot", action="store_true") +parser.add_argument("--gpu", dest="gpu", action="store_true") +parser.set_defaults(plot=False, gpu=True, train=True) + +args = parser.parse_args() + +seed = args.seed +n_neurons = args.n_neurons +n_epochs = args.n_epochs +examples = args.examples +n_workers = args.n_workers +time = args.time +dt = args.dt +intensity = args.intensity +progress_interval = args.progress_interval +update_interval = args.update_interval +train = args.train +plot = args.plot +gpu = args.gpu + +np.random.seed(seed) +torch.cuda.manual_seed_all(seed) +torch.manual_seed(seed) + +# Sets up Gpu use +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if gpu and torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) +else: + torch.manual_seed(seed) + device = "cpu" + if gpu: + gpu = False +torch.set_num_threads(os.cpu_count() - 1) +print("Running on Device = ", device) + + +### Base model ### +model = Network() +model.to(device) + + +### Layers ### +input_l = Input(n=784, shape=(1, 28, 28), traces=True) +output_l = LIFNodes( + n=n_neurons, thresh=-52 + np.random.randn(n_neurons).astype(float), traces=True +) + +model.add_layer(input_l, name="X") +model.add_layer(output_l, name="Y") + + +### Connections ### +p = torch.rand(input_l.n, output_l.n) +d = torch.rand(input_l.n, output_l.n) / 5 +w = torch.sign(torch.randint(-1, +2, (input_l.n, output_l.n))) +prob_feature = Probability(name="input_prob_feature", value=p) +weight_feature = Weight(name="input_weight_feature", value=w) +pipeline = [prob_feature, weight_feature] +input_con = MulticompartmentConnection( + source=input_l, + target=output_l, + device=device, + pipeline=pipeline, +) + +p = torch.rand(output_l.n, output_l.n) +d = torch.rand(output_l.n, output_l.n) / 5 +w = torch.sign(torch.randint(-1, +2, (output_l.n, output_l.n))) +prob_feature = Probability(name="recc_prob_feature", value=p) +weight_feature = Weight(name="recc_weight_feature", value=w) +pipeline = [prob_feature, weight_feature] +recurrent_con = MulticompartmentConnection( + source=output_l, + target=output_l, + device=device, + pipeline=pipeline, +) + +model.add_connection(input_con, source="X", target="Y") +model.add_connection(recurrent_con, source="Y", target="Y") + +# Directs network to GPU +if gpu: + model.to("cuda") + +### MNIST ### +dataset = MNIST( + PoissonEncoder(time=time, dt=dt), + None, + root=os.path.join("../../test", "..", "data", "MNIST"), + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)] + ), +) + + +### Monitor setup ### +inpt_axes = None +inpt_ims = None +spike_axes = None +spike_ims = None +weights_im = None +weights_im2 = None +voltage_ims = None +voltage_axes = None +spikes = {} +voltages = {} +for l in model.layers: + spikes[l] = Monitor(model.layers[l], ["s"], time=time, device=device) + model.add_monitor(spikes[l], name="%s_spikes" % l) + +voltages = {"Y": Monitor(model.layers["Y"], ["v"], time=time, device=device)} +model.add_monitor(voltages["Y"], name="Y_voltages") + + +### Running model on MNIST ### + +# Create a dataloader to iterate and batch data +dataloader = torch.utils.data.DataLoader( + dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True +) + +n_iters = examples +training_pairs = [] +pbar = tqdm(enumerate(dataloader)) +for i, dataPoint in pbar: + if i > n_iters: + break + + # Extract & resize the MNIST samples image data for training + # int(time / dt) -> length of spike train + # 28 x 28 -> size of sample + datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device) + label = dataPoint["label"] + pbar.set_description_str("Train progress: (%d / %d)" % (i, n_iters)) + + # Run network on sample image + model.run(inputs={"X": datum}, time=time, input_time_dim=1, reward=1.0) + training_pairs.append([spikes["Y"].get("s").sum(0), label]) + + # Plot spiking activity using monitors + if plot: + inpt_axes, inpt_ims = plot_input( + dataPoint["image"].view(28, 28), + datum.view(int(time / dt), 784).sum(0).view(28, 28), + label=label, + axes=inpt_axes, + ims=inpt_ims, + ) + spike_ims, spike_axes = plot_spikes( + {layer: spikes[layer].get("s").view(time, -1) for layer in spikes}, + axes=spike_axes, + ims=spike_ims, + ) + voltage_ims, voltage_axes = plot_voltages( + {layer: voltages[layer].get("v").view(time, -1) for layer in voltages}, + ims=voltage_ims, + axes=voltage_axes, + ) + + plt.pause(1e-8) + model.reset_state_variables() + + +### Classification ### +# Define logistic regression model using PyTorch. +# These neurons will take the reservoirs output as its input, and be trained to classify the images. +class NN(nn.Module): + def __init__(self, input_size, num_classes): + super(NN, self).__init__() + # h = int(input_size/2) + self.linear_1 = nn.Linear(input_size, num_classes) + # self.linear_1 = nn.Linear(input_size, h) + # self.linear_2 = nn.Linear(h, num_classes) + + def forward(self, x): + out = torch.sigmoid(self.linear_1(x.float().view(-1))) + # out = torch.sigmoid(self.linear_2(out)) + return out + + +# Create and train logistic regression model on reservoir outputs. +learning_model = NN(n_neurons, 10).to(device) +criterion = torch.nn.MSELoss(reduction="sum") +optimizer = torch.optim.SGD(learning_model.parameters(), lr=1e-4, momentum=0.9) + +# Training the Model +print("\n Training the read out") +pbar = tqdm(enumerate(range(n_epochs))) +for epoch, _ in pbar: + avg_loss = 0 + + # Extract spike outputs from reservoir for a training sample + # i -> Loop index + # s -> Reservoir output spikes + # l -> Image label + for i, (s, l) in enumerate(training_pairs): + # Reset gradients to 0 + optimizer.zero_grad() + + # Run spikes through logistic regression model + outputs = learning_model(s) + + # Calculate MSE + label = torch.zeros(1, 1, 10).float().to(device) + label[0, 0, l] = 1.0 + loss = criterion(outputs.view(1, 1, -1), label) + avg_loss += loss.data + + # Optimize parameters + loss.backward() + optimizer.step() + + pbar.set_description_str( + "Epoch: %d/%d, Loss: %.4f" + % (epoch + 1, n_epochs, avg_loss / len(training_pairs)) + ) + +# Run same simulation on reservoir with testing data instead of training data +# (see training section for intuition) +n_iters = examples +test_pairs = [] +pbar = tqdm(enumerate(dataloader)) +for i, dataPoint in pbar: + if i > n_iters: + break + datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device) + label = dataPoint["label"] + pbar.set_description_str("Testing progress: (%d / %d)" % (i, n_iters)) + + model.run(inputs={"X": datum}, time=time, input_time_dim=1) + test_pairs.append([spikes["Y"].get("s").sum(0), label]) + + if plot: + inpt_axes, inpt_ims = plot_input( + dataPoint["image"].view(28, 28), + datum.view(time, 784).sum(0).view(28, 28), + label=label, + axes=inpt_axes, + ims=inpt_ims, + ) + spike_ims, spike_axes = plot_spikes( + {layer: spikes[layer].get("s").view(time, -1) for layer in spikes}, + axes=spike_axes, + ims=spike_ims, + ) + voltage_ims, voltage_axes = plot_voltages( + {layer: voltages[layer].get("v").view(time, -1) for layer in voltages}, + ims=voltage_ims, + axes=voltage_axes, + ) + + plt.pause(1e-8) + model.reset_state_variables() + +# Test learning model with previously trained logistic regression classifier +correct, total = 0, 0 +for s, label in test_pairs: + outputs = learning_model(s) + _, predicted = torch.max(outputs.data.unsqueeze(0), 1) + total += 1 + correct += int(predicted == label.long().to(device)) + +print( + "\n Accuracy of the model on %d test images: %.2f %%" + % (n_iters, 100 * correct / total) +) diff --git a/examples/mnist/reservoir.py b/examples/mnist/reservoir.py index e61b17cc..9946fc32 100644 --- a/examples/mnist/reservoir.py +++ b/examples/mnist/reservoir.py @@ -101,7 +101,7 @@ dataset = MNIST( PoissonEncoder(time=time, dt=dt), None, - root=os.path.join("..", "data", "MNIST"), + root=os.path.join("..", "..", "data", "MNIST"), download=True, transform=transforms.Compose( [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]