Add LeakyReLU neuron models
drasmuss committed Jan 23, 2020
1 parent 9256848 commit 6210ee4
Showing 5 changed files with 174 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
@@ -27,6 +27,8 @@ Release history
  Layers/parameters that cannot be fully converted to native Nengo objects to be
  converted in a way that only matches the inference behaviour of the source Keras model
  (not the training behaviour). (`#119`_)
- Added ``nengo_dl.LeakyReLU`` and ``nengo_dl.SpikingLeakyReLU`` neuron models.
  (`#126`_)

**Changed**

@@ -60,6 +62,7 @@ Release history
.. _#119: https://github.com/nengo/nengo-dl/pull/119
.. _#1591: https://github.com/nengo/nengo/pull/1591
.. _#128: https://github.com/nengo/nengo-dl/pull/128
.. _#126: https://github.com/nengo/nengo-dl/pull/126

3.0.0 (December 17, 2019)
-------------------------
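As a quick illustration of the new neuron types described in the changelog entry above, here is a minimal usage sketch (not part of the commit; it is distilled from the test added below, and the ensemble size, gain, and bias values are illustrative choices):

import nengo
import numpy as np

import nengo_dl

with nengo.Network() as net:
    # rate-mode leaky ReLU: output = amplitude * where(J < 0, negative_slope * J, J)
    ens = nengo.Ensemble(
        10,
        1,
        neuron_type=nengo_dl.LeakyReLU(negative_slope=0.1, amplitude=2),
        gain=nengo.dists.Choice([1]),
        bias=np.linspace(-400, 400, 10),
    )
    # swapping in nengo_dl.SpikingLeakyReLU(negative_slope=0.1, amplitude=2) gives
    # the spiking equivalent, which can emit negative spikes for negative currents
    probe = nengo.Probe(ens.neurons)

with nengo_dl.Simulator(net) as sim:
    sim.run(1.0)
    # summed output over 1 s approximates the neuron firing rates
    rates = np.sum(sim.data[probe], axis=0) * sim.dt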
6 changes: 5 additions & 1 deletion nengo_dl/__init__.py
@@ -58,6 +58,10 @@
from nengo_dl import callbacks, compat, converter, dists, losses
from nengo_dl.config import configure_settings
from nengo_dl.converter import Converter
-from nengo_dl.neurons import SoftLIFRate
+from nengo_dl.neurons import (
+    LeakyReLU,
+    SpikingLeakyReLU,
+    SoftLIFRate,
+)
from nengo_dl.simulator import Simulator
from nengo_dl.tensor_node import TensorNode, Layer, tensor_layer
25 changes: 22 additions & 3 deletions nengo_dl/neuron_builders.py
@@ -12,7 +12,7 @@

from nengo_dl import utils
from nengo_dl.builder import Builder, OpBuilder
-from nengo_dl.neurons import SoftLIFRate
+from nengo_dl.neurons import LeakyReLU, SoftLIFRate, SpikingLeakyReLU

logger = logging.getLogger(__name__)

@@ -132,8 +132,20 @@ def __init__(self, ops, signals, config):
                signals.dtype,
            )

        if all(getattr(op.neurons, "negative_slope", 0) == 0 for op in ops):
            self.negative_slope = None
        else:
            self.negative_slope = signals.op_constant(
                [op.neurons for op in ops],
                [op.J.shape[0] for op in ops],
                "negative_slope",
                signals.dtype,
            )

    def _step(self, J):
        out = tf.nn.relu(J)
        if self.negative_slope is not None:
            out -= self.negative_slope * tf.nn.relu(-J)
        if self.amplitude is not None:
            out *= self.amplitude
        return out
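The ``_step`` above folds the leaky branch into ``tf.nn.relu(J) - self.negative_slope * tf.nn.relu(-J)`` rather than an explicit conditional. A small NumPy sketch (illustrative only, not part of the commit) checking that this branchless form matches the ``np.where`` form used by the NumPy neuron model in ``nengo_dl/neurons.py`` below:

import numpy as np

def relu(x):
    return np.maximum(x, 0)

J = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
negative_slope = 0.1

# branchless form used in the TensorFlow builder
branchless = relu(J) - negative_slope * relu(-J)
# explicit form used in the NumPy neuron model
explicit = np.where(J < 0, negative_slope * J, J)

assert np.allclose(branchless, explicit)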
@@ -157,8 +169,13 @@ def __init__(self, ops, signals, config):
self.alpha /= signals.dt

    def _step(self, J, voltage, dt):
-        voltage += tf.nn.relu(J) * dt
-        n_spikes = tf.floor(voltage)
+        if self.negative_slope is None:
+            voltage += tf.nn.relu(J) * dt
+            n_spikes = tf.floor(voltage)
+        else:
+            voltage += (tf.nn.relu(J) - self.negative_slope * tf.nn.relu(-J)) * dt
+            n_spikes = tf.floor(voltage) + tf.cast(voltage < 0, voltage.dtype)
+
        voltage -= n_spikes
        out = n_spikes * self.alpha

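Similarly, the spiking path above counts spikes with ``tf.floor(voltage) + tf.cast(voltage < 0, voltage.dtype)``, which rounds toward zero so that negative voltages yield negative spike counts. A quick NumPy check (illustrative only; it assumes voltages do not land exactly on negative integers, where the two expressions differ) that this matches the ``np.trunc`` used by ``SpikingLeakyReLU.step_math`` below:

import numpy as np

voltage = np.array([-1.7, -0.3, 0.0, 0.4, 2.6])

# floor plus a +1 correction for negative values, as in the TensorFlow builder
tf_style = np.floor(voltage) + (voltage < 0).astype(voltage.dtype)
# truncation toward zero, as in the NumPy SpikingLeakyReLU model
np_style = np.trunc(voltage)

assert np.allclose(tf_style, np_style)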
@@ -406,6 +423,8 @@ class SimNeuronsBuilder(OpBuilder):
    TF_NEURON_IMPL = {
        RectifiedLinear: RectifiedLinearBuilder,
        SpikingRectifiedLinear: SpikingRectifiedLinearBuilder,
        LeakyReLU: RectifiedLinearBuilder,
        SpikingLeakyReLU: SpikingRectifiedLinearBuilder,
        Sigmoid: SigmoidBuilder,
        LIF: LIFBuilder,
        LIFRate: LIFRateBuilder,
65 changes: 64 additions & 1 deletion nengo_dl/neurons.py
@@ -2,7 +2,7 @@
Additions to the `neuron types included with Nengo <nengo.neurons.NeuronType>`.
"""

-from nengo.neurons import LIFRate
+from nengo.neurons import LIFRate, RectifiedLinear, SpikingRectifiedLinear
from nengo.params import NumberParam
import numpy as np

@@ -76,3 +76,66 @@ def step_math(self, dt, J, output):

        q = np.where(j_valid, np.log1p(1 / z), -js - np.log(self.sigma))
        output[:] = self.amplitude / (self.tau_ref + self.tau_rc * q)


class LeakyReLU(RectifiedLinear):
    """
    Rectified linear neuron with nonzero slope for values < 0.

    Parameters
    ----------
    negative_slope : float
        Scaling factor applied to values less than zero.
    amplitude : float
        Scaling factor on the neuron output. Note that this will combine
        multiplicatively with ``negative_slope`` for values < 0.
    """

    def __init__(self, negative_slope=0.3, amplitude=1):
        super().__init__(amplitude=amplitude)

        self.negative_slope = negative_slope

    def step_math(self, dt, J, output):
        """Implement the leaky relu nonlinearity."""

        output[...] = self.amplitude * np.where(J < 0, self.negative_slope * J, J)


class SpikingLeakyReLU(SpikingRectifiedLinear):
    """
    Spiking version of `.LeakyReLU`.

    Note that this may output "negative spikes" (i.e. a spike with a sign of -1).

    Parameters
    ----------
    negative_slope : float
        Scaling factor applied to values less than zero.
    amplitude : float
        Scaling factor on the neuron output. Note that this will combine
        multiplicatively with ``negative_slope`` for values < 0.
    """

    def __init__(self, negative_slope=0.3, amplitude=1):
        super().__init__(amplitude=amplitude)

        self.negative_slope = negative_slope

    def rates(self, x, gain, bias):
        """Use `.LeakyReLU` to determine rates."""

        J = self.current(x, gain, bias)
        out = np.zeros_like(J)
        LeakyReLU.step_math(self, dt=1, J=J, output=out)
        return out

    def step_math(self, dt, J, spiked, voltage):
        """
        Implement the spiking leaky relu nonlinearity.
        """

        voltage += np.where(J < 0, self.negative_slope * J, J) * dt
        n_spikes = np.trunc(voltage)
        spiked[:] = (self.amplitude / dt) * n_spikes
        voltage -= n_spikes
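To make the "negative spikes" behaviour concrete, here is a standalone NumPy sketch of the ``step_math`` update above (a re-implementation for illustration, not the class itself; the input current and slope values are arbitrary): a constant negative input produces spikes of sign -1 whose time-average approximates the rate-mode output ``negative_slope * J``:

import numpy as np

dt = 0.001
negative_slope = 0.1
amplitude = 1.0
J = np.array([-200.0])  # constant negative input current

voltage = np.zeros_like(J)
spiked = np.zeros_like(J)
accumulated = np.zeros_like(J)

for _ in range(1000):  # simulate 1 second of spiking
    voltage += np.where(J < 0, negative_slope * J, J) * dt
    n_spikes = np.trunc(voltage)  # truncation toward zero gives negative spike counts
    spiked[:] = (amplitude / dt) * n_spikes
    voltage -= n_spikes
    accumulated += spiked * dt

# the averaged spike train approximates the rate-mode output negative_slope * J = -20
assert np.allclose(accumulated, negative_slope * J, atol=1)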
81 changes: 80 additions & 1 deletion nengo_dl/tests/test_neurons.py
@@ -4,7 +4,14 @@
import numpy as np
import pytest

from nengo_dl import config, dists, SoftLIFRate, neuron_builders
from nengo_dl import (
LeakyReLU,
SpikingLeakyReLU,
SoftLIFRate,
config,
dists,
neuron_builders,
)


def test_lif_deterministic(Simulator, seed):
@@ -176,3 +183,75 @@ def test_spiking_swap(Simulator, rate, spiking, seed):

    # check that the gradients match
    assert all(np.allclose(g0, g1) for g0, g1 in zip(*grads))


@pytest.mark.parametrize("Neurons", (LeakyReLU, SpikingLeakyReLU))
def test_leaky_relu(Simulator, Neurons):
    assert np.allclose(Neurons(negative_slope=0.1).rates([-2, 2], 1, 0), [[-0.2], [2]])

    assert np.allclose(
        Neurons(negative_slope=0.1, amplitude=0.1).rates([-2, 2], 1, 0),
        [[-0.02], [0.2]],
    )

    with nengo.Network() as net:
        vals = np.linspace(-400, 400, 10)
        ens0 = nengo.Ensemble(
            10,
            1,
            neuron_type=Neurons(negative_slope=0.1, amplitude=2),
            gain=nengo.dists.Choice([1]),
            bias=vals,
        )
        ens1 = nengo.Ensemble(
            10,
            1,
            neuron_type=Neurons(negative_slope=0.5),
            gain=nengo.dists.Choice([1]),
            bias=vals,
        )
        p0 = nengo.Probe(ens0.neurons)
        p1 = nengo.Probe(ens1.neurons)

    with Simulator(net) as sim:
        # make sure that ops have been merged
        assert (
            len(
                [
                    ops
                    for ops in sim.tensor_graph.plan
                    if isinstance(ops[0], nengo.builder.neurons.SimNeurons)
                ]
            )
            == 1
        )

        sim.run(1.0)

        assert np.allclose(
            np.sum(sim.data[p0], axis=0) * sim.dt,
            np.where(vals < 0, vals * 0.1 * 2, vals * 2),
            atol=1,
        )

        assert np.allclose(
            np.sum(sim.data[p1], axis=0) * sim.dt,
            np.where(vals < 0, vals * 0.5, vals),
            atol=1,
        )

    # check that it works in the regular nengo simulator as well
    with nengo.Simulator(net) as sim:
        sim.run(1.0)

        assert np.allclose(
            np.sum(sim.data[p0], axis=0) * sim.dt,
            np.where(vals < 0, vals * 0.1 * 2, vals * 2),
            atol=1,
        )

        assert np.allclose(
            np.sum(sim.data[p1], axis=0) * sim.dt,
            np.where(vals < 0, vals * 0.5, vals),
            atol=1,
        )
