From 12c5f41c1d9f3ee147b8d59f7ec7d9195ee278a6 Mon Sep 17 00:00:00 2001
From: Ridger <57249513+ridgerchu@users.noreply.github.com>
Date: Thu, 7 Jul 2022 17:00:43 +0800
Subject: [PATCH 1/8] Add new feature 'probe'

---
 snntorch/functional/probe.py | 214 +++++++++++++++++++++++++++++++++++
 1 file changed, 214 insertions(+)
 create mode 100644 snntorch/functional/probe.py

diff --git a/snntorch/functional/probe.py b/snntorch/functional/probe.py
new file mode 100644
index 00000000..199ec58c
--- /dev/null
+++ b/snntorch/functional/probe.py
@@ -0,0 +1,214 @@
+import torch
+from torch import nn
+from typing import Callable, Any
+
+
+def unpack_len1_tuple(x: tuple or torch.Tensor):
+    if isinstance(x, tuple) and x.__len__() == 1:
+        return x[0]
+    else:
+        return x
+
+
+class BaseMonitor:
+    def __init__(self):
+        self.hooks = []
+        self.monitored_layers = []
+        self.records = []
+        self.name_records_index = {}
+        self._enable = True
+
+    def __getitem__(self, i):
+        if isinstance(i, int):
+            return self.records[i]
+        elif isinstance(i, str):
+            y = []
+            for index in self.name_records_index[i]:
+                y.append(self.records[index])
+            return y
+        else:
+            raise ValueError(i)
+
+    def clear_recorded_data(self):
+        self.records.clear()
+        for k, v in self.name_records_index.items():
+            v.clear()
+
+    def enable(self):
+        self._enable = True
+
+    def disable(self):
+        self._enable = False
+
+    def is_enable(self):
+        return self._enable
+
+    def remove_hooks(self):
+        for hook in self.hooks:
+            hook.remove()
+
+    def __del__(self):
+        self.remove_hooks()
+
+
+class OutputMonitor(BaseMonitor):
+    def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_output: Callable = lambda x: x):
+        super().__init__()
+        self.function_on_output = function_on_output
+        for name, m in net.named_modules():
+            if isinstance(m, instance):
+                self.monitored_layers.append(name)
+                self.name_records_index[name] = []
+                self.hooks.append(m.register_forward_hook(self.create_hook(name)))
+
+    def create_hook(self, name):
+        def hook(m, x, y):
+            if self.is_enable():
+                self.name_records_index[name].append(self.records.__len__())
+                self.records.append(self.function_on_output(unpack_len1_tuple(y)))
+
+        return hook
+
+
+class InputMonitor(BaseMonitor):
+    def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_input: Callable = lambda x: x):
+        super().__init__()
+        self.function_on_input = function_on_input
+        for name, m in net.named_modules():
+            if isinstance(m, instance):
+                self.monitored_layers.append(name)
+                self.name_records_index[name] = []
+                self.hooks.append(m.register_forward_hook(self.create_hook(name)))
+
+    def create_hook(self, name):
+        def hook(m, x, y):
+            if self.is_enable():
+                self.name_records_index[name].append(self.records.__len__())
+                self.records.append(self.function_on_input(unpack_len1_tuple(x)))
+
+        return hook
+
+
+class AttributeMonitor(BaseMonitor):
+    def __init__(self, attribute_name: str, pre_forward: bool, net: nn.Module, instance: Any or tuple = None,
+                 function_on_attribute: Callable = lambda x: x):
+        super().__init__()
+        self.attribute_name = attribute_name
+        self.function_on_attribute = function_on_attribute
+
+        for name, m in net.named_modules():
+            if isinstance(m, instance):
+                self.monitored_layers.append(name)
+                self.name_records_index[name] = []
+                if pre_forward:
+                    self.hooks.append(
+                        m.register_forward_pre_hook(self.create_hook(name))
+                    )
+                else:
+                    self.hooks.append(
+                        m.register_forward_hook(self.create_hook(name))
+                    )
+
+    def create_hook(self, name):
+        def hook(m, x, y):
+            if self.is_enable():
+                self.name_records_index[name].append(self.records.__len__())
+                self.records.append(self.function_on_attribute(m.__getattr__(self.attribute_name)))
+
+        return hook
+
+
+class GradInputMonitor(BaseMonitor):
+    def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_grad_input: Callable = lambda x: x):
+        super().__init__()
+        self.function_on_grad_input = function_on_grad_input
+
+        for name, m in net.named_modules():
+            if isinstance(m, instance):
+                self.monitored_layers.append(name)
+                self.name_records_index[name] = []
+                if torch.__version__ >= torch.torch_version.TorchVersion('1.8.0'):
+                    self.hooks.append(m.register_full_backward_hook(self.create_hook(name)))
+                else:
+                    self.hooks.append(m.register_backward_hook(self.create_hook(name)))
+
+    def create_hook(self, name):
+        def hook(m, grad_input, grad_output):
+            if self.is_enable():
+                self.name_records_index[name].append(self.records.__len__())
+                self.records.append(self.function_on_grad_input(unpack_len1_tuple(grad_input)))
+
+        return hook
+
+
+class GradOutputMonitor(BaseMonitor):
+    def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_grad_output: Callable = lambda x: x):
+        """
+        :param net: a network
+        :type net: nn.Module
+        :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+        :type instance: Any or tuple
+        :param function_on_grad_output: the function that applies on the grad of the monitored modules' outputs
+        :type function_on_grad_output: Callable
+
+        Applies ``function_on_grad_output`` on the grad of the outputs of all modules whose instances are ``instance`` in ``net``,
+        and records the data into ``self.records``, which is a ``list``.
+        Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
+        Call ``self.clear_recorded_data()`` to clear the recorded data.
+
+        Refer to the tutorial about the monitor for more details.
+
+        Code example:
+
+        .. code-block:: python
+
+            import torch
+            import torch.nn as nn
+            from spikingjelly.activation_based import monitor, neuron, functional, layer
+
+            class Net(nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.fc1 = layer.Linear(8, 4)
+                    self.sn1 = neuron.IFNode()
+                    self.fc2 = layer.Linear(4, 2)
+                    self.sn2 = neuron.IFNode()
+                    functional.set_step_mode(self, 'm')
+
+                def forward(self, x_seq: torch.Tensor):
+                    x_seq = self.fc1(x_seq)
+                    x_seq = self.sn1(x_seq)
+                    x_seq = self.fc2(x_seq)
+                    x_seq = self.sn2(x_seq)
+                    return x_seq
+
+            net = Net()
+            for param in net.parameters():
+                param.data.abs_()
+
+            mtor = monitor.GradOutputMonitor(net, instance=neuron.IFNode)
+
+            net(torch.rand([1, 8])).sum().backward()
+            print(f'mtor.records={mtor.records}')
+            # mtor.records=[tensor([[1., 1.]]), tensor([[0.1372, 0.1081, 0.0880, 0.1089]])]
+            print(f'mtor[0]={mtor[0]}')
+            # mtor[0]=tensor([[1., 1.]])
+            print(f'mtor.monitored_layers={mtor.monitored_layers}')
+            # mtor.monitored_layers=['sn1', 'sn2']
+            print(f"mtor['sn1']={mtor['sn1']}")
+            # mtor['sn1']=[tensor([[0.1372, 0.1081, 0.0880, 0.1089]])]
+        """
+        super().__init__()
+        self.function_on_grad_output = function_on_grad_output
+        for name, m in net.named_modules():
+            if isinstance(m, instance):
+                self.monitored_layers.append(name)
+                self.name_records_index[name] = []
+                if torch.__version__ >= torch.torch_version.TorchVersion('1.8.0'):
+                    self.hooks.append(m.register_full_backward_hook(self.create_hook(name)))
+                else:
+                    self.hooks.append(m.register_backward_hook(self.create_hook(name)))
+
+    def create_hook(self, name):
+        def hook(m, grad_input, grad_output):
+            if self.is_enable():
+                self.name_records_index[name].append(self.records.__len__())
+                self.records.append(self.function_on_grad_output(unpack_len1_tuple(grad_output)))
+
+        return hook
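The mechanism this first patch builds on is worth spelling out: every hook appends its processed value to ``self.records`` and stores that record's position under the layer's name in ``self.name_records_index``, which is what lets ``BaseMonitor.__getitem__`` answer both integer and string queries. A minimal sketch of that recording flow using only plain ``torch.nn`` modules (the throwaway ``net`` below is illustrative, not part of the patch)::

    import torch
    from torch import nn
    from snntorch.functional import probe

    # Any nn.Module qualifies: the monitor selects layers with isinstance()
    # over net.named_modules(), so pass the class itself to `instance`.
    net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2), nn.ReLU())
    mon = probe.OutputMonitor(net, instance=nn.ReLU)

    with torch.no_grad():
        net(torch.rand([1, 8]))

    print(mon.monitored_layers)  # ['1', '3'] -- names from named_modules()
    print(mon[0].shape)          # records[0]: output of the first ReLU
    print(mon['1'])              # every record attributed to layer '1'

    mon.clear_recorded_data()    # empty records and the per-name index
    mon.remove_hooks()           # detach the forward hooks when done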
From f1c8b3714c05f403b7c64025af8ad1416c58f6e3 Mon Sep 17 00:00:00 2001
From: Ridger <57249513+ridgerchu@users.noreply.github.com>
Date: Thu, 7 Jul 2022 17:15:04 +0800
Subject: [PATCH 2/8] Update probe.py

---
 snntorch/functional/probe.py | 55 ------------------------------------
 1 file changed, 55 deletions(-)

diff --git a/snntorch/functional/probe.py b/snntorch/functional/probe.py
index 199ec58c..9216ad7c 100644
--- a/snntorch/functional/probe.py
+++ b/snntorch/functional/probe.py
@@ -139,61 +139,6 @@ def hook(m, grad_input, grad_output):
 class GradOutputMonitor(BaseMonitor):
     def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_grad_output: Callable = lambda x: x):
-        """
-        :param net: a network
-        :type net: nn.Module
-        :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
-        :type instance: Any or tuple
-        :param function_on_grad_output: the function that applies on the grad of the monitored modules' outputs
-        :type function_on_grad_output: Callable
-
-        Applies ``function_on_grad_output`` on the grad of the outputs of all modules whose instances are ``instance`` in ``net``,
-        and records the data into ``self.records``, which is a ``list``.
-        Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
-        Call ``self.clear_recorded_data()`` to clear the recorded data.
-
-        Refer to the tutorial about the monitor for more details.
-
-        Code example:
-
-        .. code-block:: python
-
-            import torch
-            import torch.nn as nn
-            from spikingjelly.activation_based import monitor, neuron, functional, layer
-
-            class Net(nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.fc1 = layer.Linear(8, 4)
-                    self.sn1 = neuron.IFNode()
-                    self.fc2 = layer.Linear(4, 2)
-                    self.sn2 = neuron.IFNode()
-                    functional.set_step_mode(self, 'm')
-
-                def forward(self, x_seq: torch.Tensor):
-                    x_seq = self.fc1(x_seq)
-                    x_seq = self.sn1(x_seq)
-                    x_seq = self.fc2(x_seq)
-                    x_seq = self.sn2(x_seq)
-                    return x_seq
-
-            net = Net()
-            for param in net.parameters():
-                param.data.abs_()
-
-            mtor = monitor.GradOutputMonitor(net, instance=neuron.IFNode)
-
-            net(torch.rand([1, 8])).sum().backward()
-            print(f'mtor.records={mtor.records}')
-            # mtor.records=[tensor([[1., 1.]]), tensor([[0.1372, 0.1081, 0.0880, 0.1089]])]
-            print(f'mtor[0]={mtor[0]}')
-            # mtor[0]=tensor([[1., 1.]])
-            print(f'mtor.monitored_layers={mtor.monitored_layers}')
-            # mtor.monitored_layers=['sn1', 'sn2']
-            print(f"mtor['sn1']={mtor['sn1']}")
-            # mtor['sn1']=[tensor([[0.1372, 0.1081, 0.0880, 0.1089]])]
-        """
         super().__init__()
         self.function_on_grad_output = function_on_grad_output
         for name, m in net.named_modules():
             if isinstance(m, instance):
From 9e07e4f21f3101d3509a760eed002260f917bf4e Mon Sep 17 00:00:00 2001
From: Ridger <57249513+ridgerchu@users.noreply.github.com>
Date: Thu, 7 Jul 2022 18:18:15 +0800
Subject: [PATCH 3/8] Update __init__.py

---
 snntorch/functional/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/snntorch/functional/__init__.py b/snntorch/functional/__init__.py
index d385d047..65f8aa48 100644
--- a/snntorch/functional/__init__.py
+++ b/snntorch/functional/__init__.py
@@ -2,3 +2,4 @@
 from .acc import *
 from .loss import *
 from .reg import *
+from .probe import *
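A one-line wildcard re-export, but it changes how downstream code can reach the probes: everything public in ``probe.py`` now also lives in the ``snntorch.functional`` namespace. A quick sketch of the import paths this enables (assuming the patch series is installed)::

    import snntorch.functional as SF
    from snntorch.functional.probe import OutputMonitor

    # Both names resolve to the same class after this patch.
    assert SF.OutputMonitor is OutputMonitor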
From 07066dc0ace78351786f0640413a1163a40d04f5 Mon Sep 17 00:00:00 2001
From: Ridger <57249513+ridgerchu@users.noreply.github.com>
Date: Tue, 12 Jul 2022 11:24:36 +0800
Subject: [PATCH 4/8] add probe to doc.

---
 docs/snntorch.functional.rst | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/docs/snntorch.functional.rst b/docs/snntorch.functional.rst
index 49094c5f..6b9f0a40 100644
--- a/docs/snntorch.functional.rst
+++ b/docs/snntorch.functional.rst
@@ -56,4 +56,12 @@ State Quantization
 .. automodule:: snntorch.functional.quant
    :members:
    :undoc-members:
-   :show-inheritance:
\ No newline at end of file
+   :show-inheritance:
+
+Probe
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. automodule:: snntorch.functional.probe
+   :members:
+   :undoc-members:
+   :show-inheritance:

From 1959a869b4d4c156bf7eb7220b30da31cccf5a1b Mon Sep 17 00:00:00 2001
From: Ridger <57249513+ridgerchu@users.noreply.github.com>
Date: Tue, 12 Jul 2022 11:57:55 +0800
Subject: [PATCH 5/8] Add doc in InputMonitor and OutputMonitor.

---
 snntorch/functional/probe.py | 94 ++++++++++++++++++++++++++++++++++++
 1 file changed, 94 insertions(+)

diff --git a/snntorch/functional/probe.py b/snntorch/functional/probe.py
index 9216ad7c..5547586c 100644
--- a/snntorch/functional/probe.py
+++ b/snntorch/functional/probe.py
@@ -52,6 +52,53 @@ def __del__(self):
 
 
 class OutputMonitor(BaseMonitor):
+    '''
+    A monitor to record the output spikes of each specific neuron layer (e.g. Leaky) in a network.
+    all of the output data will be recorded in ''self.record'' with the python data type ''list''.
+    Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
+    Call ``self.clear_recorded_data()`` to clear the recorded data.
+
+    Example::
+        import snntorch
+        from snntorch.functional import probe
+        from torch import nn
+        import torch
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(8, 4)
+                self.sn1 = snntorch.Leaky()
+                self.fc2 = nn.Linear(4, 2)
+                self.sn2 = snntorch.Leaky()
+
+            def forward(self, x_seq: torch.Tensor):
+                x_seq = self.fc1(x_seq)
+                x_seq = self.sn1(x_seq)
+                x_seq = self.fc2(x_seq)
+                x_seq = self.sn2(x_seq)
+                return x_seq
+
+        net = Net()
+        for param in net.parameters():
+            #keeps all parameter in positive to make sure spike emiting in network.
+            param.data.abs_()
+
+        mtor = probe.OutputMonitor(net, instance=snntorch.Leaky())
+
+        with torch.no_grad():
+            y = net(torch.rand([1, 8]))
+            print(f'mtor.records={mtor.records}')
+            print(f'mtor[0]={mtor[0]}')
+            print(f'mtor.monitored_layers={mtor.monitored_layers}')
+            print(f"mtor['sn1']={mtor['sn1']}")
+
+    :param net: a PyTorch network
+    :type net: nn.Module
+    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :type instance: Any or tuple
+    :param function_on_output: the function that applies on the monitored modules' outputs
+    :type function_on_output: Callable, optional
+    '''
     def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_output: Callable = lambda x: x):
         super().__init__()
         self.function_on_output = function_on_output
@@ -70,6 +117,53 @@ def hook(m, x, y):
     return hook
 
 class InputMonitor(BaseMonitor):
+    '''
+    A monitor to record the input of each specific neuron layer (e.g. Leaky) in a network.
+    all of the input data will be recorded in ''self.record'' with the python data type ''list''.
+    Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
+    Call ``self.clear_recorded_data()`` to clear the recorded data.
+
+    Example::
+        from snntorch.functional import probe
+        import snntorch
+        import torch
+        from torch import nn
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(8, 4)
+                self.sn1 = snntorch.Leaky()
+                self.fc2 = nn.Linear(4, 2)
+                self.sn2 = snntorch.Leaky()
+
+            def forward(self, x_seq: torch.Tensor):
+                x_seq = self.fc1(x_seq)
+                x_seq = self.sn1(x_seq)
+                x_seq = self.fc2(x_seq)
+                x_seq = self.sn2(x_seq)
+                return x_seq
+
+        net = Net()
+        for param in net.parameters():
+            #keeps all parameter in positive to make sure spike emiting in network.
+            param.data.abs_()
+
+        mtor = probe.InputMonitor(net, instance=snntorch.Leaky())
+
+        with torch.no_grad():
+            y = net(torch.rand([1, 8]))
+            print(f'mtor.records={mtor.records}')
+            print(f'mtor[0]={mtor[0]}')
+            print(f'mtor.monitored_layers={mtor.monitored_layers}')
+            print(f"mtor['sn1']={mtor['sn1']}")
+
+    :param net: a PyTorch network
+    :type net: nn.Module
+    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :type instance: Any or tuple
+    :param function_on_input: the function that applies on the monitored modules' inputs
+    :type function_on_input: Callable, optional
+    '''
     def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_input: Callable = lambda x: x):
         super().__init__()
         self.function_on_input = function_on_input
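One point these new docstrings only hint at: ``function_on_output`` runs inside the hook, so it can distill each output before anything is stored, which keeps ``records`` small over long simulations. A sketch under stated assumptions, namely that ``snn.Leaky`` is given the ``beta`` argument it requires plus ``init_hidden=True`` so it returns a single tensor, and that ``instance`` is passed the class itself because the monitor filters with ``isinstance()``::

    import torch
    from torch import nn
    import snntorch as snn
    from snntorch.functional import probe

    # Illustrative single-layer net; sizes and beta are arbitrary.
    net = nn.Sequential(nn.Linear(8, 4), snn.Leaky(beta=0.9, init_hidden=True))

    # Store one scalar firing rate per forward call instead of raw spike tensors.
    rate_mon = probe.OutputMonitor(
        net,
        instance=snn.Leaky,
        function_on_output=lambda spk: spk.mean().item(),
    )

    with torch.no_grad():
        for _ in range(5):  # five hypothetical time steps
            net(torch.rand([1, 8]))

    print(rate_mon['1'])  # e.g. [0.25, 0.0, 0.5, 0.25, 0.25]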
From e883dfb762ecf531365e71d8b7d0e18580d13223 Mon Sep 17 00:00:00 2001
From: Ridger <57249513+ridgerchu@users.noreply.github.com>
Date: Tue, 26 Jul 2022 09:57:36 +0800
Subject: [PATCH 6/8] Add doc for all functions.

---
 snntorch/functional/probe.py | 146 +++++++++++++++++++++++++++++++++++
 1 file changed, 146 insertions(+)

diff --git a/snntorch/functional/probe.py b/snntorch/functional/probe.py
index 5547586c..ac533157 100644
--- a/snntorch/functional/probe.py
+++ b/snntorch/functional/probe.py
@@ -182,6 +182,58 @@ def hook(m, x, y):
     return hook
 
 class AttributeMonitor(BaseMonitor):
+    '''
+    A monitor to record the attribute (e.g. membrane potential) of a specific neuron layer (e.g. Leaky) in a network.
+    You could specify the attribute name as the first parameter of this function.
+    all of the input data will be recorded in ''self.record'' with the python data type ''list''.
+    Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
+    Call ``self.clear_recorded_data()`` to clear the recorded data.
+
+    Example::
+        from snntorch.functional import probe
+        import snntorch
+        import torch
+        from torch import nn
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(8, 4)
+                self.sn1 = snntorch.Leaky()
+                self.fc2 = nn.Linear(4, 2)
+                self.sn2 = snntorch.Leaky()
+
+            def forward(self, x_seq: torch.Tensor):
+                x_seq = self.fc1(x_seq)
+                x_seq = self.sn1(x_seq)
+                x_seq = self.fc2(x_seq)
+                x_seq = self.sn2(x_seq)
+                return x_seq
+
+        net = Net()
+        for param in net.parameters():
+            #keeps all parameter in positive to make sure spike emiting in network.
+            param.data.abs_()
+
+        mtor = probe.AttributeMonitor('mem', False, net, instance=snntorch.Leaky())
+
+        with torch.no_grad():
+            y = net(torch.rand([1, 8]))
+            print(f'mtor.records={mtor.records}')
+            print(f'mtor[0]={mtor[0]}')
+            print(f'mtor.monitored_layers={mtor.monitored_layers}')
+            print(f"mtor['sn1']={mtor['sn1']}")
+
+    :param attribute_name: the attribute's name of probed neuron layer
+    :type net: str
+    :param pre_forward: If "True", recording the attribute value before feed forward, otherwise recording the value after feed forward.
+    :type pre_forward: bool
+    :param net: a PyTorch network
+    :type net: nn.Module
+    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :type instance: Any or tuple
+    :param function_on_attribute: the function that applies on the monitored modules' inputs
+    :type function_on_attribute: Callable, optional
+    '''
     def __init__(self, attribute_name: str, pre_forward: bool, net: nn.Module, instance: Any or tuple = None,
                  function_on_attribute: Callable = lambda x: x):
         super().__init__()
@@ -210,6 +262,53 @@ def hook(m, x, y):
     return hook
 
 class GradInputMonitor(BaseMonitor):
+    '''
+    A monitor to record the input gradient of each specific neuron layer (e.g. Leaky) in a network.
+    all of the input data will be recorded in ''self.record'' with the python data type ''list''.
+    Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
+    Call ``self.clear_recorded_data()`` to clear the recorded data.
+
+    Example::
+        from snntorch.functional import probe
+        import snntorch
+        import torch
+        from torch import nn
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(8, 4)
+                self.sn1 = snntorch.Leaky()
+                self.fc2 = nn.Linear(4, 2)
+                self.sn2 = snntorch.Leaky()
+
+            def forward(self, x_seq: torch.Tensor):
+                x_seq = self.fc1(x_seq)
+                x_seq = self.sn1(x_seq)
+                x_seq = self.fc2(x_seq)
+                x_seq = self.sn2(x_seq)
+                return x_seq
+
+        net = Net()
+        for param in net.parameters():
+            #keeps all parameter in positive to make sure spike emiting in network.
+            param.data.abs_()
+
+        mtor = probe.GradInputMonitor(net, instance=snntorch.Leaky())
+
+        with torch.no_grad():
+            y = net(torch.rand([1, 8]))
+            print(f'mtor.records={mtor.records}')
+            print(f'mtor[0]={mtor[0]}')
+            print(f'mtor.monitored_layers={mtor.monitored_layers}')
+            print(f"mtor['sn1']={mtor['sn1']}")
+
+    :param net: a PyTorch network
+    :type net: nn.Module
+    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :type instance: Any or tuple
+    :param function_on_grad_input: the function that applies on the monitored modules' inputs
+    :type function_on_grad_input: Callable, optional
+    '''
     def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_grad_input: Callable = lambda x: x):
         super().__init__()
         self.function_on_grad_input = function_on_grad_input
@@ -232,6 +331,53 @@ def hook(m, grad_input, grad_output):
     return hook
 
 class GradOutputMonitor(BaseMonitor):
+    '''
+    A monitor to record the output gradient of each specific neuron layer (e.g. Leaky) in a network.
+    all of the input data will be recorded in ''self.record'' with the python data type ''list''.
+    Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
+    Call ``self.clear_recorded_data()`` to clear the recorded data.
+
+    Example::
+        from snntorch.functional import probe
+        import snntorch
+        import torch
+        from torch import nn
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(8, 4)
+                self.sn1 = snntorch.Leaky()
+                self.fc2 = nn.Linear(4, 2)
+                self.sn2 = snntorch.Leaky()
+
+            def forward(self, x_seq: torch.Tensor):
+                x_seq = self.fc1(x_seq)
+                x_seq = self.sn1(x_seq)
+                x_seq = self.fc2(x_seq)
+                x_seq = self.sn2(x_seq)
+                return x_seq
+
+        net = Net()
+        for param in net.parameters():
+            #keeps all parameter in positive to make sure spike emiting in network.
+            param.data.abs_()
+
+        mtor = probe.GradOutputMonitor(net, instance=snntorch.Leaky())
+
+        with torch.no_grad():
+            y = net(torch.rand([1, 8]))
+            print(f'mtor.records={mtor.records}')
+            print(f'mtor[0]={mtor[0]}')
+            print(f'mtor.monitored_layers={mtor.monitored_layers}')
+            print(f"mtor['sn1']={mtor['sn1']}")
+
+    :param net: a PyTorch network
+    :type net: nn.Module
+    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :type instance: Any or tuple
+    :param function_on_grad_output: the function that applies on the monitored modules' inputs
+    :type function_on_grad_output: Callable, optional
+    '''
     def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_grad_output: Callable = lambda x: x):
         super().__init__()
         self.function_on_grad_output = function_on_grad_output
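Both gradient probes documented above register their hooks behind a torch-version guard: ``register_full_backward_hook`` (PyTorch 1.8+) reports ``grad_input`` and ``grad_output`` reliably, while the legacy ``register_backward_hook`` it falls back to is deprecated and can misreport gradients for modules with multiple inputs. The same guard in isolation, mirroring the patch's check (a sketch; note it assumes a PyTorch recent enough to ship ``torch.torch_version``)::

    import torch
    from torch import nn

    lin = nn.Linear(4, 2)

    def show_grad(module, grad_input, grad_output):
        # One entry per module input/output; None where no gradient flows.
        print([None if g is None else tuple(g.shape) for g in grad_output])

    if torch.__version__ >= torch.torch_version.TorchVersion('1.8.0'):
        handle = lin.register_full_backward_hook(show_grad)
    else:
        handle = lin.register_backward_hook(show_grad)

    # Backward hooks only fire during backward(), hence no torch.no_grad() here.
    lin(torch.rand(3, 4)).sum().backward()
    handle.remove()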
From cb568446e281c54a15e374de9b0769b7f25ca79b Mon Sep 17 00:00:00 2001
From: Ridger <57249513+ridgerchu@users.noreply.github.com>
Date: Tue, 26 Jul 2022 10:02:43 +0800
Subject: [PATCH 7/8] fix some bugs.

---
 snntorch/functional/probe.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/snntorch/functional/probe.py b/snntorch/functional/probe.py
index ac533157..b774369f 100644
--- a/snntorch/functional/probe.py
+++ b/snntorch/functional/probe.py
@@ -182,7 +182,7 @@ def hook(m, x, y):
     return hook
 
 class AttributeMonitor(BaseMonitor):
-    ''' 
+    '''
     A monitor to record the attribute (e.g. membrane potential) of a specific neuron layer (e.g. Leaky) in a network.
     You could specify the attribute name as the first parameter of this function.
     all of the input data will be recorded in ''self.record'' with the python data type ''list''.
@@ -225,7 +225,7 @@
 
     :param attribute_name: the attribute's name of probed neuron layer
     :type net: str
-    :param pre_forward: If "True", recording the attribute value before feed forward, otherwise recording the value after feed forward.
+    :param pre_forward: If ``True``, recording the attribute value before feed forward, otherwise recording the value after feed forward.
     :type pre_forward: bool
     :param net: a PyTorch network
     :type net: nn.Module
@@ -262,7 +262,7 @@ def hook(m, x, y):
     return hook
 
 class GradInputMonitor(BaseMonitor):
-    ''' 
+    '''
     A monitor to record the input gradient of each specific neuron layer (e.g. Leaky) in a network.
     all of the input data will be recorded in ''self.record'' with the python data type ''list''.
     Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
@@ -331,7 +331,7 @@ def hook(m, grad_input, grad_output):
     return hook
 
 class GradOutputMonitor(BaseMonitor):
-    ''' 
+    '''
    A monitor to record the output gradient of each specific neuron layer (e.g. Leaky) in a network.
     all of the input data will be recorded in ''self.record'' with the python data type ''list''.
     Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
From 8b79b20c56ff1baae7617b04f71b58df06918009 Mon Sep 17 00:00:00 2001
From: Jason Eshraghian
Date: Tue, 26 Jul 2022 16:20:02 +0800
Subject: [PATCH 8/8] update docstrings and examples for probe functions

---
 snntorch/functional/probe.py | 202 +++++++++++++++++++----------------
 1 file changed, 107 insertions(+), 95 deletions(-)

diff --git a/snntorch/functional/probe.py b/snntorch/functional/probe.py
index b774369f..c0a552c5 100644
--- a/snntorch/functional/probe.py
+++ b/snntorch/functional/probe.py
@@ -54,50 +54,53 @@ def __del__(self):
 
 class OutputMonitor(BaseMonitor):
     '''
     A monitor to record the output spikes of each specific neuron layer (e.g. Leaky) in a network.
-    all of the output data will be recorded in ''self.record'' with the python data type ''list''.
+    All output data is recorded in ``self.records`` as data type ``list``.
     Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
-    Call ``self.clear_recorded_data()`` to clear the recorded data.
+    Call ``self.clear_recorded_data()`` to clear recorded data.
 
     Example::
-        import snntorch
+        import snntorch as snn
         from snntorch.functional import probe
-        from torch import nn
         import torch
+        from torch import nn
+
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.fc1 = nn.Linear(8, 4)
-                self.sn1 = snntorch.Leaky()
+                self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
                 self.fc2 = nn.Linear(4, 2)
-                self.sn2 = snntorch.Leaky()
+                self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)
 
             def forward(self, x_seq: torch.Tensor):
                 x_seq = self.fc1(x_seq)
-                x_seq = self.sn1(x_seq)
+                x_seq = self.lif1(x_seq)
                 x_seq = self.fc2(x_seq)
-                x_seq = self.sn2(x_seq)
+                x_seq = self.lif2(x_seq)
                 return x_seq
 
         net = Net()
-        for param in net.parameters():
-            #keeps all parameter in positive to make sure spike emiting in network.
-            param.data.abs_()
 
-        mtor = probe.OutputMonitor(net, instance=snntorch.Leaky())
+        monitor = probe.OutputMonitor(net, instance=snn.Leaky)
 
         with torch.no_grad():
            y = net(torch.rand([1, 8]))
-            print(f'mtor.records={mtor.records}')
-            print(f'mtor[0]={mtor[0]}')
-            print(f'mtor.monitored_layers={mtor.monitored_layers}')
-            print(f"mtor['sn1']={mtor['sn1']}")
+            print(f'monitor.records={monitor.records}')
+            print(f'monitor[0]={monitor[0]}')
+            print(f'monitor.monitored_layers={monitor.monitored_layers}')
+            print(f"monitor['lif1']={monitor['lif1']}")
 
-    :param net: a PyTorch network
+    :param net: Network model (either wrapped in Sequential container or as a class)
     :type net: nn.Module
+
-    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :param instance: Instance of modules to be monitored. If ``None``, defaults to ``type(net)``
     :type instance: Any or tuple
+
-    :param function_on_output: the function that applies on the monitored modules' outputs
+    :param function_on_output: Function that is applied to the monitored modules' outputs
     :type function_on_output: Callable, optional
+
     '''
     def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_output: Callable = lambda x: x):
         super().__init__()
         self.function_on_output = function_on_output
 class InputMonitor(BaseMonitor):
     '''
-    A monitor to record the input of each specific neuron layer (e.g. Leaky) in a network.
-    all of the input data will be recorded in ''self.record'' with the python data type ''list''.
+    A monitor to record the input of each neuron layer (e.g. Leaky) in a network.
+    All input data is recorded in ``self.records`` as data type ``list``.
     Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
-    Call ``self.clear_recorded_data()`` to clear the recorded data.
+    Call ``self.clear_recorded_data()`` to clear recorded data.
 
     Example::
+        import snntorch as snn
         from snntorch.functional import probe
-        import snntorch
         import torch
         from torch import nn
+
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.fc1 = nn.Linear(8, 4)
-                self.sn1 = snntorch.Leaky()
+                self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
                 self.fc2 = nn.Linear(4, 2)
-                self.sn2 = snntorch.Leaky()
+                self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)
 
             def forward(self, x_seq: torch.Tensor):
                 x_seq = self.fc1(x_seq)
-                x_seq = self.sn1(x_seq)
+                x_seq = self.lif1(x_seq)
                 x_seq = self.fc2(x_seq)
-                x_seq = self.sn2(x_seq)
+                x_seq = self.lif2(x_seq)
                 return x_seq
 
         net = Net()
-        for param in net.parameters():
-            #keeps all parameter in positive to make sure spike emiting in network.
-            param.data.abs_()
 
-        mtor = probe.InputMonitor(net, instance=snntorch.Leaky())
+        monitor = probe.InputMonitor(net, instance=snn.Leaky)
 
         with torch.no_grad():
             y = net(torch.rand([1, 8]))
-            print(f'mtor.records={mtor.records}')
-            print(f'mtor[0]={mtor[0]}')
-            print(f'mtor.monitored_layers={mtor.monitored_layers}')
-            print(f"mtor['sn1']={mtor['sn1']}")
+            print(f'monitor.records={monitor.records}')
+            print(f'monitor[0]={monitor[0]}')
+            print(f'monitor.monitored_layers={monitor.monitored_layers}')
+            print(f"monitor['lif1']={monitor['lif1']}")
 
-    :param net: a PyTorch network
+    :param net: Network model (either wrapped in Sequential container or as a class)
     :type net: nn.Module
+
-    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :param instance: Instance of modules to be monitored. If ``None``, defaults to ``type(net)``
     :type instance: Any or tuple
+
-    :param function_on_input: the function that applies on the monitored modules' inputs
+    :param function_on_input: Function that is applied to the monitored modules' inputs
     :type function_on_input: Callable, optional
+
     '''
     def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_input: Callable = lambda x: x):
         super().__init__()
         self.function_on_input = function_on_input

 class AttributeMonitor(BaseMonitor):
     '''
     A monitor to record the attribute (e.g. membrane potential) of a specific neuron layer (e.g. Leaky) in a network.
-    You could specify the attribute name as the first parameter of this function.
-    all of the input data will be recorded in ''self.record'' with the python data type ''list''.
+    The attribute name can be specified as the first argument of this function.
+    All attribute data is recorded in ``self.records`` as data type ``list``.
     Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
-    Call ``self.clear_recorded_data()`` to clear the recorded data.
+    Call ``self.clear_recorded_data()`` to clear recorded data.
 
     Example::
+        import snntorch as snn
         from snntorch.functional import probe
-        import snntorch
         import torch
         from torch import nn
+
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.fc1 = nn.Linear(8, 4)
-                self.sn1 = snntorch.Leaky()
+                self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
                 self.fc2 = nn.Linear(4, 2)
-                self.sn2 = snntorch.Leaky()
+                self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)
 
             def forward(self, x_seq: torch.Tensor):
                 x_seq = self.fc1(x_seq)
-                x_seq = self.sn1(x_seq)
+                x_seq = self.lif1(x_seq)
                 x_seq = self.fc2(x_seq)
-                x_seq = self.sn2(x_seq)
+                x_seq = self.lif2(x_seq)
                 return x_seq
 
         net = Net()
-        for param in net.parameters():
-            #keeps all parameter in positive to make sure spike emiting in network.
-            param.data.abs_()
 
-        mtor = probe.AttributeMonitor('mem', False, net, instance=snntorch.Leaky())
+        monitor = probe.AttributeMonitor('mem', False, net, instance=snn.Leaky)
 
         with torch.no_grad():
             y = net(torch.rand([1, 8]))
-            print(f'mtor.records={mtor.records}')
-            print(f'mtor[0]={mtor[0]}')
-            print(f'mtor.monitored_layers={mtor.monitored_layers}')
-            print(f"mtor['sn1']={mtor['sn1']}")
+            print(f'monitor.records={monitor.records}')
+            print(f'monitor[0]={monitor[0]}')
+            print(f'monitor.monitored_layers={monitor.monitored_layers}')
+            print(f"monitor['lif1']={monitor['lif1']}")
 
-    :param attribute_name: the attribute's name of probed neuron layer
-    :type net: str
+    :param attribute_name: Attribute's name of probed neuron layer (e.g., mem, syn, etc.)
+    :type attribute_name: str
+
-    :param pre_forward: If ``True``, recording the attribute value before feed forward, otherwise recording the value after feed forward.
+    :param pre_forward: If ``True``, record the attribute value before the forward pass, otherwise record the value after forward pass.
     :type pre_forward: bool
+
-    :param net: a PyTorch network
+    :param net: Network model (either wrapped in Sequential container or as a class)
     :type net: nn.Module
+
-    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :param instance: Instance of modules to be monitored. If ``None``, defaults to ``type(net)``
     :type instance: Any or tuple
+
-    :param function_on_attribute: the function that applies on the monitored modules' inputs
+    :param function_on_attribute: Function that is applied to the monitored modules' attribute
     :type function_on_attribute: Callable, optional
     '''
+
     def __init__(self, attribute_name: str, pre_forward: bool, net: nn.Module, instance: Any or tuple = None,
                  function_on_attribute: Callable = lambda x: x):
         super().__init__()

 class GradInputMonitor(BaseMonitor):
     '''
-    A monitor to record the input gradient of each specific neuron layer (e.g. Leaky) in a network.
-    all of the input data will be recorded in ''self.record'' with the python data type ''list''.
+    A monitor to record the input gradient of each neuron layer (e.g. Leaky) in a network.
+    All input gradient data is recorded in ``self.records`` as data type ``list``.
     Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
-    Call ``self.clear_recorded_data()`` to clear the recorded data.
+    Call ``self.clear_recorded_data()`` to clear recorded data.
 
     Example::
+        import snntorch as snn
         from snntorch.functional import probe
-        import snntorch
         import torch
         from torch import nn
+
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.fc1 = nn.Linear(8, 4)
-                self.sn1 = snntorch.Leaky()
+                self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
                 self.fc2 = nn.Linear(4, 2)
-                self.sn2 = snntorch.Leaky()
+                self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)
 
             def forward(self, x_seq: torch.Tensor):
                 x_seq = self.fc1(x_seq)
-                x_seq = self.sn1(x_seq)
+                x_seq = self.lif1(x_seq)
                 x_seq = self.fc2(x_seq)
-                x_seq = self.sn2(x_seq)
+                x_seq = self.lif2(x_seq)
                 return x_seq
 
         net = Net()
-        for param in net.parameters():
-            #keeps all parameter in positive to make sure spike emiting in network.
-            param.data.abs_()
 
-        mtor = probe.GradInputMonitor(net, instance=snntorch.Leaky())
+        monitor = probe.GradInputMonitor(net, instance=snn.Leaky)
 
-        with torch.no_grad():
-            y = net(torch.rand([1, 8]))
-            print(f'mtor.records={mtor.records}')
-            print(f'mtor[0]={mtor[0]}')
-            print(f'mtor.monitored_layers={mtor.monitored_layers}')
-            print(f"mtor['sn1']={mtor['sn1']}")
+        y = net(torch.rand([1, 8]))
+        y.sum().backward()
+        print(f'monitor.records={monitor.records}')
+        print(f'monitor[0]={monitor[0]}')
+        print(f'monitor.monitored_layers={monitor.monitored_layers}')
+        print(f"monitor['lif1']={monitor['lif1']}")
 
-    :param net: a PyTorch network
+    :param net: Network model (either wrapped in Sequential container or as a class)
     :type net: nn.Module
+
-    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :param instance: Instance of modules to be monitored. If ``None``, defaults to ``type(net)``
     :type instance: Any or tuple
+
-    :param function_on_grad_input: the function that applies on the monitored modules' inputs
+    :param function_on_grad_input: Function that is applied to the monitored modules' gradients
     :type function_on_grad_input: Callable, optional
     '''
+
     def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_grad_input: Callable = lambda x: x):
         super().__init__()
         self.function_on_grad_input = function_on_grad_input

 class GradOutputMonitor(BaseMonitor):
     '''
     A monitor to record the output gradient of each specific neuron layer (e.g. Leaky) in a network.
-    all of the input data will be recorded in ''self.record'' with the python data type ''list''.
+    All output gradient data is recorded in ``self.records`` as data type ``list``.
     Call ``self.enable()`` or ``self.disable()`` to enable or disable the monitor.
-    Call ``self.clear_recorded_data()`` to clear the recorded data.
+    Call ``self.clear_recorded_data()`` to clear recorded data.
 
     Example::
+        import snntorch as snn
         from snntorch.functional import probe
-        import snntorch
         import torch
         from torch import nn
+
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.fc1 = nn.Linear(8, 4)
-                self.sn1 = snntorch.Leaky()
+                self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
                 self.fc2 = nn.Linear(4, 2)
-                self.sn2 = snntorch.Leaky()
+                self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)
 
             def forward(self, x_seq: torch.Tensor):
                 x_seq = self.fc1(x_seq)
-                x_seq = self.sn1(x_seq)
+                x_seq = self.lif1(x_seq)
                 x_seq = self.fc2(x_seq)
-                x_seq = self.sn2(x_seq)
+                x_seq = self.lif2(x_seq)
                 return x_seq
 
         net = Net()
-        for param in net.parameters():
-            #keeps all parameter in positive to make sure spike emiting in network.
-            param.data.abs_()
 
-        mtor = probe.GradOutputMonitor(net, instance=snntorch.Leaky())
+        mtor = probe.GradOutputMonitor(net, instance=snn.Leaky)
 
-        with torch.no_grad():
-            y = net(torch.rand([1, 8]))
-            print(f'mtor.records={mtor.records}')
-            print(f'mtor[0]={mtor[0]}')
-            print(f'mtor.monitored_layers={mtor.monitored_layers}')
-            print(f"mtor['sn1']={mtor['sn1']}")
+        y = net(torch.rand([1, 8]))
+        y.sum().backward()
+        print(f'mtor.records={mtor.records}')
+        print(f'mtor[0]={mtor[0]}')
+        print(f'mtor.monitored_layers={mtor.monitored_layers}')
+        print(f"mtor['lif1']={mtor['lif1']}")
 
-    :param net: a PyTorch network
+    :param net: Network model (either wrapped in Sequential container or as a class)
     :type net: nn.Module
+
-    :param instance: the instance of modules to be monitored. If ``None``, it will be regarded as ``type(net)``
+    :param instance: Instance of modules to be monitored. If ``None``, defaults to ``type(net)``
     :type instance: Any or tuple
+
-    :param function_on_grad_output: the function that applies on the monitored modules' inputs
+    :param function_on_grad_output: Function that is applied to the monitored modules' gradients
     :type function_on_grad_output: Callable, optional
     '''
+
     def __init__(self, net: nn.Module, instance: Any or tuple = None, function_on_grad_output: Callable = lambda x: x):
         super().__init__()
         self.function_on_grad_output = function_on_grad_output
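Taken together, the series leaves ``snntorch.functional.probe`` with five monitors sharing one lifecycle: construct against a network, read back by index or by layer name, and enable, disable, clear, or detach as needed. A closing sketch of that lifecycle around the gradient probe (the two-layer net, ``beta`` value, and sizes are illustrative, not prescribed by the patches)::

    import torch
    from torch import nn
    import snntorch as snn
    from snntorch.functional import probe

    net = nn.Sequential(
        nn.Linear(8, 4),
        snn.Leaky(beta=0.9, init_hidden=True),
        nn.Linear(4, 2),
        snn.Leaky(beta=0.9, init_hidden=True),
    )

    mon = probe.GradOutputMonitor(net, instance=snn.Leaky)

    # Build a graph and backpropagate so the backward hooks fire.
    net(torch.rand([1, 8])).sum().backward()

    print(mon.monitored_layers)  # ['1', '3'] -- the two Leaky layers
    print(len(mon.records))      # 2: one grad_output record per Leaky layer

    mon.disable()                # hooks stay registered but record nothing
    mon.enable()                 # ...and can be switched back on
    mon.clear_recorded_data()    # drop stored tensors, keep the hooks
    mon.remove_hooks()           # detach everything when finished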