From 6210ee47a15519fd57c056c3017dc985c95313ac Mon Sep 17 00:00:00 2001
From: Daniel Rasmussen
Date: Mon, 20 Jan 2020 13:03:07 -0400
Subject: [PATCH] Add LeakyReLU neuron models

---
 CHANGES.rst                    |  3 ++
 nengo_dl/__init__.py           |  6 ++-
 nengo_dl/neuron_builders.py    | 25 +++++++++--
 nengo_dl/neurons.py            | 65 ++++++++++++++++++++++++++-
 nengo_dl/tests/test_neurons.py | 81 +++++++++++++++++++++++++++++++++-
 5 files changed, 174 insertions(+), 6 deletions(-)

diff --git a/CHANGES.rst b/CHANGES.rst
index b2f867d6c..80d975089 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -27,6 +27,8 @@ Release history
   Layers/parameters that cannot be fully converted to native Nengo objects to be
   converted in a way that only matches the inference behaviour of the source
   Keras model (not the training behaviour). (`#119`_)
+- Added ``nengo_dl.LeakyReLU`` and ``nengo_dl.SpikingLeakyReLU`` neuron models.
+  (`#126`_)
 
 **Changed**
 
@@ -60,6 +62,7 @@ Release history
 .. _#119: https://github.com/nengo/nengo-dl/pull/119
 .. _#1591: https://github.com/nengo/nengo/pull/1591
 .. _#128: https://github.com/nengo/nengo-dl/pull/128
+.. _#126: https://github.com/nengo/nengo-dl/pull/126
 
 3.0.0 (December 17, 2019)
 -------------------------
diff --git a/nengo_dl/__init__.py b/nengo_dl/__init__.py
index 5f738e210..0b5ab4ee7 100644
--- a/nengo_dl/__init__.py
+++ b/nengo_dl/__init__.py
@@ -58,6 +58,10 @@
 from nengo_dl import callbacks, compat, converter, dists, losses
 from nengo_dl.config import configure_settings
 from nengo_dl.converter import Converter
-from nengo_dl.neurons import SoftLIFRate
+from nengo_dl.neurons import (
+    LeakyReLU,
+    SpikingLeakyReLU,
+    SoftLIFRate,
+)
 from nengo_dl.simulator import Simulator
 from nengo_dl.tensor_node import TensorNode, Layer, tensor_layer
diff --git a/nengo_dl/neuron_builders.py b/nengo_dl/neuron_builders.py
index 49d0d71e3..8a67b3233 100644
--- a/nengo_dl/neuron_builders.py
+++ b/nengo_dl/neuron_builders.py
@@ -12,7 +12,7 @@
 from nengo_dl import utils
 from nengo_dl.builder import Builder, OpBuilder
-from nengo_dl.neurons import SoftLIFRate
+from nengo_dl.neurons import LeakyReLU, SoftLIFRate, SpikingLeakyReLU
 
 logger = logging.getLogger(__name__)
 
@@ -132,8 +132,20 @@ def __init__(self, ops, signals, config):
             signals.dtype,
         )
 
+        if all(getattr(op.neurons, "negative_slope", 0) == 0 for op in ops):
+            self.negative_slope = None
+        else:
+            self.negative_slope = signals.op_constant(
+                [op.neurons for op in ops],
+                [op.J.shape[0] for op in ops],
+                "negative_slope",
+                signals.dtype,
+            )
+
     def _step(self, J):
         out = tf.nn.relu(J)
+        if self.negative_slope is not None:
+            out -= self.negative_slope * tf.nn.relu(-J)
         if self.amplitude is not None:
             out *= self.amplitude
         return out
@@ -157,8 +169,13 @@ def __init__(self, ops, signals, config):
             self.alpha /= signals.dt
 
     def _step(self, J, voltage, dt):
-        voltage += tf.nn.relu(J) * dt
-        n_spikes = tf.floor(voltage)
+        if self.negative_slope is None:
+            voltage += tf.nn.relu(J) * dt
+            n_spikes = tf.floor(voltage)
+        else:
+            voltage += (tf.nn.relu(J) - self.negative_slope * tf.nn.relu(-J)) * dt
+            n_spikes = tf.floor(voltage) + tf.cast(voltage < 0, voltage.dtype)
+
         voltage -= n_spikes
         out = n_spikes * self.alpha
@@ -406,6 +423,8 @@ class SimNeuronsBuilder(OpBuilder):
     TF_NEURON_IMPL = {
         RectifiedLinear: RectifiedLinearBuilder,
         SpikingRectifiedLinear: SpikingRectifiedLinearBuilder,
+        LeakyReLU: RectifiedLinearBuilder,
+        SpikingLeakyReLU: SpikingRectifiedLinearBuilder,
         Sigmoid: SigmoidBuilder,
         LIF: LIFBuilder,
         LIFRate: LIFRateBuilder,
diff --git a/nengo_dl/neurons.py b/nengo_dl/neurons.py
index e16cea866..db1fefaf6 100644
--- a/nengo_dl/neurons.py
+++ b/nengo_dl/neurons.py
@@ -2,7 +2,7 @@
 Additions to the `neuron types included with Nengo`.
 """
 
-from nengo.neurons import LIFRate
+from nengo.neurons import LIFRate, RectifiedLinear, SpikingRectifiedLinear
 from nengo.params import NumberParam
 
 import numpy as np
@@ -76,3 +76,66 @@ def step_math(self, dt, J, output):
         q = np.where(j_valid, np.log1p(1 / z), -js - np.log(self.sigma))
 
         output[:] = self.amplitude / (self.tau_ref + self.tau_rc * q)
+
+
+class LeakyReLU(RectifiedLinear):
+    """
+    Rectified linear neuron with nonzero slope for values < 0.
+
+    Parameters
+    ----------
+    negative_slope : float
+        Scaling factor applied to values less than zero.
+    amplitude : float
+        Scaling factor on the neuron output. Note that this will combine
+        multiplicatively with ``negative_slope`` for values < 0.
+    """
+
+    def __init__(self, negative_slope=0.3, amplitude=1):
+        super().__init__(amplitude=amplitude)
+
+        self.negative_slope = negative_slope
+
+    def step_math(self, dt, J, output):
+        """Implement the leaky relu nonlinearity."""
+
+        output[...] = self.amplitude * np.where(J < 0, self.negative_slope * J, J)
+
+
+class SpikingLeakyReLU(SpikingRectifiedLinear):
+    """
+    Spiking version of `.LeakyReLU`.
+
+    Note that this may output "negative spikes" (i.e. a spike with a sign of -1).
+
+    Parameters
+    ----------
+    negative_slope : float
+        Scaling factor applied to values less than zero.
+    amplitude : float
+        Scaling factor on the neuron output. Note that this will combine
+        multiplicatively with ``negative_slope`` for values < 0.
+    """
+
+    def __init__(self, negative_slope=0.3, amplitude=1):
+        super().__init__(amplitude=amplitude)
+
+        self.negative_slope = negative_slope
+
+    def rates(self, x, gain, bias):
+        """Use `.LeakyReLU` to determine rates."""
+
+        J = self.current(x, gain, bias)
+        out = np.zeros_like(J)
+        LeakyReLU.step_math(self, dt=1, J=J, output=out)
+        return out
+
+    def step_math(self, dt, J, spiked, voltage):
+        """
+        Implement the spiking leaky relu nonlinearity.
+ """ + + voltage += np.where(J < 0, self.negative_slope * J, J) * dt + n_spikes = np.trunc(voltage) + spiked[:] = (self.amplitude / dt) * n_spikes + voltage -= n_spikes diff --git a/nengo_dl/tests/test_neurons.py b/nengo_dl/tests/test_neurons.py index 3d425238c..5392f6874 100644 --- a/nengo_dl/tests/test_neurons.py +++ b/nengo_dl/tests/test_neurons.py @@ -4,7 +4,14 @@ import numpy as np import pytest -from nengo_dl import config, dists, SoftLIFRate, neuron_builders +from nengo_dl import ( + LeakyReLU, + SpikingLeakyReLU, + SoftLIFRate, + config, + dists, + neuron_builders, +) def test_lif_deterministic(Simulator, seed): @@ -176,3 +183,75 @@ def test_spiking_swap(Simulator, rate, spiking, seed): # check that the gradients match assert all(np.allclose(g0, g1) for g0, g1 in zip(*grads)) + + +@pytest.mark.parametrize("Neurons", (LeakyReLU, SpikingLeakyReLU)) +def test_leaky_relu(Simulator, Neurons): + assert np.allclose(Neurons(negative_slope=0.1).rates([-2, 2], 1, 0), [[-0.2], [2]]) + + assert np.allclose( + Neurons(negative_slope=0.1, amplitude=0.1).rates([-2, 2], 1, 0), + [[-0.02], [0.2]], + ) + + with nengo.Network() as net: + vals = np.linspace(-400, 400, 10) + ens0 = nengo.Ensemble( + 10, + 1, + neuron_type=Neurons(negative_slope=0.1, amplitude=2), + gain=nengo.dists.Choice([1]), + bias=vals, + ) + ens1 = nengo.Ensemble( + 10, + 1, + neuron_type=Neurons(negative_slope=0.5), + gain=nengo.dists.Choice([1]), + bias=vals, + ) + p0 = nengo.Probe(ens0.neurons) + p1 = nengo.Probe(ens1.neurons) + + with Simulator(net) as sim: + # make sure that ops have been merged + assert ( + len( + [ + ops + for ops in sim.tensor_graph.plan + if isinstance(ops[0], nengo.builder.neurons.SimNeurons) + ] + ) + == 1 + ) + + sim.run(1.0) + + assert np.allclose( + np.sum(sim.data[p0], axis=0) * sim.dt, + np.where(vals < 0, vals * 0.1 * 2, vals * 2), + atol=1, + ) + + assert np.allclose( + np.sum(sim.data[p1], axis=0) * sim.dt, + np.where(vals < 0, vals * 0.5, vals), + atol=1, + ) + + # check that it works in the regular nengo simulator as well + with nengo.Simulator(net) as sim: + sim.run(1.0) + + assert np.allclose( + np.sum(sim.data[p0], axis=0) * sim.dt, + np.where(vals < 0, vals * 0.1 * 2, vals * 2), + atol=1, + ) + + assert np.allclose( + np.sum(sim.data[p1], axis=0) * sim.dt, + np.where(vals < 0, vals * 0.5, vals), + atol=1, + )