Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Sparse transforms in Simulator.get_nengo_params #149

Merged
merged 4 commits into from
Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .nengobones.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ setup_cfg:
pylint:
known_third_party:
- PIL
- packaging
- progressbar
- tensorflow
coverage:
Expand Down
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ Release history
3.2.1 (unreleased)
------------------

**Fixed**

- Support Sparse transforms in ``Simulator.get_nengo_params``. (`#149`_)

.. _#149: https://github.com/nengo/nengo-dl/pull/149

3.2.0 (April 2, 2020)
---------------------
Expand Down
21 changes: 18 additions & 3 deletions nengo_dl/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
dependencies.
"""

from distutils.version import LooseVersion

import nengo
from nengo._vendor.scipy.sparse import linalg_interface, linalg_onenormest
from packaging import version
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import network
Expand Down Expand Up @@ -76,7 +76,7 @@ def filter(self, record):

tf.get_logger().addFilter(TFLogFilter(err_on_deprecation=False))

if LooseVersion(tf.__version__) < "2.2.0":
if version.parse(tf.__version__) < version.parse("2.2.0rc0"):

def global_learning_phase():
"""Returns the global (eager) Keras learning phase."""
Expand Down Expand Up @@ -117,12 +117,25 @@ def _conform_to_reference_input(self, tensor, ref_input):

network.Network._conform_to_reference_input = _conform_to_reference_input

if version.parse(tf.__version__) < version.parse("2.1.0rc0"):
hunse marked this conversation as resolved.
Show resolved Hide resolved
from tensorflow.python.keras.layers import (
BatchNormalization as BatchNormalizationV1,
)
from tensorflow.python.keras.layers import BatchNormalizationV2
else:
from tensorflow.python.keras.layers import (
BatchNormalizationV1,
BatchNormalizationV2,
)

# Nengo compatibility

# monkeypatch fix for https://github.com/nengo/nengo/pull/1587
linalg_onenormest.aslinearoperator = linalg_interface.aslinearoperator

if LooseVersion(nengo.__version__) < "3.1.0":
if version.parse(nengo.__version__) < version.parse("3.1.0.dev0"):
NoTransform = type(None)

default_transform = 1

def conn_has_weights(conn):
Expand All @@ -131,6 +144,8 @@ def conn_has_weights(conn):


else:
from nengo.transforms import NoTransform

default_transform = None

def conn_has_weights(conn):
Expand Down
10 changes: 6 additions & 4 deletions nengo_dl/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import nengo
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers import BatchNormalization, BatchNormalizationV2
from tensorflow.python.util import nest

from nengo_dl import compat
Expand Down Expand Up @@ -1170,8 +1169,8 @@ def convert(self, node_id):
return super().convert(node_id, dimensions=3)


@Converter.register(BatchNormalization)
@Converter.register(BatchNormalizationV2)
@Converter.register(compat.BatchNormalizationV1)
@Converter.register(compat.BatchNormalizationV2)
class ConvertBatchNormalization(LayerConverter):
"""Convert ``tf.keras.layers.BatchNormalization`` to Nengo objects."""

Expand Down Expand Up @@ -1308,7 +1307,10 @@ def convert(self, node_id, dimensions):

# add trainable bias weights
bias_node = nengo.Node([1], label="%s.%d.bias" % (self.layer.name, node_id))
bias_relay = nengo.Node(size_in=len(biases))
bias_relay = nengo.Node(
size_in=len(biases),
label="%s.%d.bias_relay" % (self.layer.name, node_id),
)
nengo.Connection(
bias_node, bias_relay, transform=biases[:, None], synapse=None
)
Expand Down
2 changes: 1 addition & 1 deletion nengo_dl/graph_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,7 @@ def is_identity(x, sig):

return (
x.shape == (d, d)
and np.allclose(np.diag(x), 1)
and np.allclose(x.diagonal(), 1)
and np.allclose(np.sum(np.abs(x)), d) # all non-diagonal elements are 0
)

Expand Down
64 changes: 49 additions & 15 deletions nengo_dl/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
import warnings

import jinja2
from nengo import Connection, Direct, Ensemble, Network, Node, Probe, Convolution
from nengo import (
Connection,
Direct,
Ensemble,
Network,
Node,
Probe,
)
from nengo import rc as nengo_rc
from nengo.builder.connection import BuiltConnection
from nengo.builder.ensemble import BuiltEnsemble
Expand All @@ -25,6 +32,7 @@
ValidationError,
)
from nengo.solvers import NoSolver
from nengo.transforms import Convolution, Dense, Sparse, SparseMatrix
from nengo.utils.magic import decorator
import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -1358,6 +1366,12 @@ def get_nengo_params(self, nengo_objs, as_dict=False):
e = nengo.Ensemble(10, 1, **params[1])
f = nengo.Connection(d, e, **params[2])

Note that this function only returns trainable parameters (e.g. connection
weights, biases, or encoders), or parameters that directly interact with
those parameters (e.g. gains). Other arguments that are independent of the
trainable parameters (e.g. ``Ensemble.neuron_type`` or ``Connection.synapse``)
should be specified manually (since they may change between models).

Parameters
----------
nengo_objs : (list of) `~nengo.Ensemble` or `~nengo.Connection`
Expand Down Expand Up @@ -1427,27 +1441,47 @@ def get_nengo_params(self, nengo_objs, as_dict=False):

weights = data[idx]
idx += 1
if isinstance(obj.pre_obj, Ensemble):
params.append(
{
"solver": NoSolver(weights.T, weights=False),
"function": lambda x, weights=weights: np.zeros(
weights.shape[0]
),
"transform": compat.default_transform,
}
)
elif isinstance(obj.transform, Convolution):
if isinstance(obj.transform, Convolution):
transform = copy.copy(obj.transform)
# manually bypass the read-only check (we are sure that
# nothing else has a handle to the new transform at this
# point, so this won't cause any problems)
Convolution.init.data[transform] = weights
params.append({"transform": transform})
elif isinstance(obj.transform, Sparse):
transform = copy.copy(obj.transform)
if isinstance(transform.init, SparseMatrix):
init = SparseMatrix(
transform.init.indices, weights, transform.init.shape
)
else:
init = transform.init.tocoo()
init = SparseMatrix(
np.stack((init.row, init.col), axis=-1), weights, init.shape
)
Sparse.init.data[transform] = init
params.append({"transform": transform})
elif isinstance(obj.transform, (Dense, compat.NoTransform)):
if isinstance(obj.pre_obj, Ensemble):
# decoded connection
params.append(
{
"solver": NoSolver(weights.T, weights=False),
"function": lambda x, weights=weights: np.zeros(
weights.shape[0]
),
"transform": compat.default_transform,
}
)
else:
if all(x == 1 for x in weights.shape):
weights = np.squeeze(weights)
params.append({"transform": weights})
else:
if all(x == 1 for x in weights.shape):
weights = np.squeeze(weights)
params.append({"transform": weights})
raise NotImplementedError(
"Cannot get parameters of Connections with transform type '%s'"
% type(obj.transform).__name__
)
else:
# note: we don't want to change the original gain (even though
# it is rolled into the encoder values), because connections
Expand Down
13 changes: 9 additions & 4 deletions nengo_dl/tests/test_converter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# pylint: disable=missing-docstring

from distutils.version import LooseVersion

import nengo
import numpy as np
from packaging import version
import pytest
import tensorflow as tf
from tensorflow.python.keras.layers import BatchNormalization
Expand Down Expand Up @@ -175,12 +174,18 @@ def test_batch_normalization(rng):
# TF<2.1 doesn't support axis!=-1 for fused batchnorm
out.append(
tf.keras.layers.BatchNormalization(
axis=1, fused=None if LooseVersion(tf.__version__) >= "2.1.0" else False
axis=1,
fused=False
if version.parse(tf.__version__) < version.parse("2.1.0rc0")
else None,
)(inp)
)
out.append(
tf.keras.layers.BatchNormalization(
axis=2, fused=None if LooseVersion(tf.__version__) >= "2.1.0" else False
axis=2,
fused=False
if version.parse(tf.__version__) < version.parse("2.1.0rc0")
else None,
)(inp)
)
out.append(tf.keras.layers.BatchNormalization()(inp))
Expand Down
5 changes: 2 additions & 3 deletions nengo_dl/tests/test_graph_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# pylint: disable=missing-docstring

from distutils.version import LooseVersion

import nengo
from nengo.exceptions import BuildError
from nengo.neurons import LIF, LIFRate, Izhikevich, AdaptiveLIF
Expand All @@ -20,6 +18,7 @@
from nengo.builder.signal import Signal
from nengo.builder.transforms import ConvInc
import numpy as np
from packaging import version
import pytest

from nengo_dl import config, op_builders, transform_builders
Expand Down Expand Up @@ -1183,7 +1182,7 @@ def test_remove_reset_inc_functional(Simulator, seed):
p = nengo.Probe(node1)

with Simulator(net) as sim:
extra_op = LooseVersion(nengo.__version__) < "3.1.0"
extra_op = version.parse(nengo.__version__) < version.parse("3.1.0.dev0")

assert len(sim.tensor_graph.plan) == 8 + extra_op

Expand Down
10 changes: 5 additions & 5 deletions nengo_dl/tests/test_learning_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def test_merged_learning(Simulator, rule, weights, seed):
assert np.allclose(sim.data[p1][i], canonical[1])


def test_online_learning_reset(Simulator, tmpdir):
with nengo.Network() as net:
def test_online_learning_reset(Simulator, tmpdir, seed):
with nengo.Network(seed=seed) as net:
inp = nengo.Ensemble(10, 1)
out = nengo.Node(size_in=1)
conn = nengo.Connection(inp, out, learning_rule_type=nengo.PES(1))
Expand All @@ -98,12 +98,12 @@ def test_online_learning_reset(Simulator, tmpdir):
# test that learning has changed weights
assert not np.allclose(w0, w1)

# test that soft reset does NOT reset the online learning weights
sim.soft_reset()
# test that include_trainable=False does NOT reset the online learning weights
sim.reset(include_trainable=False)
assert np.allclose(w1, sim.data[conn].weights)

# test that full reset DOES reset the online learning weights
sim.reset()
sim.reset(include_trainable=True)
assert np.allclose(w0, sim.data[conn].weights)

# test that weights load correctly
Expand Down
6 changes: 0 additions & 6 deletions nengo_dl/tests/test_nengo_tests.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
# pylint: disable=missing-docstring

from distutils.version import LooseVersion

import nengo
from nengo.builder.signal import Signal
from nengo.builder.operator import ElementwiseInc, DotInc
import numpy as np
import pkg_resources
import pytest
import tensorflow as tf

import nengo_dl
from nengo_dl.tests import dummies
Expand Down Expand Up @@ -77,9 +74,6 @@ def test_signal_init_values(Simulator):


def test_entry_point():
if LooseVersion(tf.__version__) == "1.11.0":
pytest.xfail("TensorFlow 1.11.0 has conflicting dependencies")

sims = [
ep.load(require=False)
for ep in pkg_resources.iter_entry_points(group="nengo.backends")
Expand Down
Loading