Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for TensorFlow 2.4 #185

Merged
merged 4 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .nengobones.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ setup_cfg:
tests/test_probe.py::test_multirun:
simulation times may not line up exactly if unroll_simulation != 1, see
tests/test_nengo_tests.py::test_multirun
tests/test_learning_rules.py::test_rls_*:
RLS learning rule not implemented
allclose_tolerances:
- tests/test_synapses.py::test_lowpass atol=5e-7
- tests/test_synapses.py::test_triangle atol=5e-7
Expand Down
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Release history
- Added support for `KerasSpiking <https://www.nengo.ai/keras-spiking/>`_ layers in
the Converter. (`#182`_)
- Added support for ``tf.keras.layers.TimeDistributed`` in the Converter. (`#182`_)
- Added support for TensorFlow 2.4. (`#185`_)

**Changed**

Expand All @@ -47,6 +48,7 @@ Release history
.. _#173: https://github.com/nengo/nengo-dl/pull/173
.. _#181: https://github.com/nengo/nengo-dl/pull/181
.. _#182: https://github.com/nengo/nengo-dl/pull/182
.. _#185: https://github.com/nengo/nengo-dl/pull/185

3.3.0 (August 14, 2020)
-----------------------
Expand Down
4 changes: 1 addition & 3 deletions nengo_dl/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,7 @@ def _get_key(self, key):
# get output tensor
key = key.output

key = tuple(
x.ref() if isinstance(x, tf.Tensor) else x for x in tf.nest.flatten(key)
)
key = tuple(x.ref() if tf.is_tensor(x) else x for x in tf.nest.flatten(key))

return key

Expand Down
4 changes: 2 additions & 2 deletions nengo_dl/neuron_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from nengo.neurons import RectifiedLinear, SpikingRectifiedLinear, Sigmoid, LIF, LIFRate
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.framework import smart_cond

from nengo_dl import compat, utils
from nengo_dl.builder import Builder, OpBuilder
Expand Down Expand Up @@ -188,7 +188,7 @@ def build_step(self, signals, **step_kwargs):
out = step_output
else:
out = tf.nest.flatten(
tf_utils.smart_cond(
smart_cond.smart_cond(
self.config.training,
true_fn=lambda: (self.training_step(J, signals.dt, **state),)
+ tuple(state.values()),
Expand Down
4 changes: 3 additions & 1 deletion nengo_dl/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,7 +1815,9 @@ def _generate_inputs(self, data=None, n_steps=None):
if data is None:
data = {}

if not isinstance(data, (list, tuple, dict, np.ndarray, tf.Tensor)):
if not isinstance(data, (list, tuple, dict, np.ndarray)) and not tf.is_tensor(
data
):
# data is some kind of generator, so we don't try to modify it (too many
# different types of generators this could be)
if n_steps is not None:
Expand Down
8 changes: 8 additions & 0 deletions nengo_dl/tensor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def __init__(
)
]

# check for unsupported operators
for op in operators:
if type(op) not in builder.Builder.builders:
raise BuildError(
"No registered builder for operators of type %s; "
"consider registering a custom builder" % type(op)
)

# mark trainable signals
self.mark_signals()

Expand Down
2 changes: 1 addition & 1 deletion nengo_dl/tensor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def validate_output(output, minibatch_size=None, output_d=None, dtype=None):
Expected dtype of the function output.
"""

if not isinstance(output, (tf.Tensor, tf.TensorSpec)):
if not isinstance(output, tf.TensorSpec) and not tf.is_tensor(output):
raise ValidationError(
"TensorNode function must return a Tensor (got %s)" % type(output),
attr="tensor_func",
Expand Down
11 changes: 10 additions & 1 deletion nengo_dl/tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,16 @@ def test_tensorboard(Simulator, tmpdir):

@pytest.mark.parametrize("mode", ("predict", "train"))
@pytest.mark.training
def test_profile(Simulator, mode, tmpdir):
def test_profile(Simulator, mode, tmpdir, pytestconfig):
if (
pytestconfig.getoption("--graph-mode")
and version.parse(tf.__version__) >= version.parse("2.4.0rc0")
and mode == "predict"
):
pytest.skip(
"TensorFlow bug, see https://github.com/tensorflow/tensorflow/issues/44563"
)

net, a, p = dummies.linear_net()

with Simulator(net) as sim:
Expand Down
11 changes: 11 additions & 0 deletions nengo_dl/tests/test_tensor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import nengo
from nengo.builder.operator import Reset
from nengo.builder.signal import Signal
from nengo.exceptions import BuildError
import numpy as np
import pytest
import tensorflow as tf
Expand Down Expand Up @@ -547,3 +548,13 @@ def test_conditional_update(Simulator, use_loop, caplog):
pass

assert "Number of state updates: 1" in caplog.text


def test_unsupported_op_error():
class MyOp(dummies.Op): # pylint: disable=abstract-method
pass

model = nengo.builder.Model()
model.add_op(MyOp())
with pytest.raises(BuildError, match="No registered builder"):
tensor_graph.TensorGraph(model, None, None, None, None, None, None)
2 changes: 1 addition & 1 deletion nengo_dl/tests/test_tensor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def __init__(self, expected):
self.expected = expected

def call(self, inputs, training=None):
tf.assert_equal(training, self.expected)
tf.assert_equal(tf.cast(training, tf.bool), self.expected)
return tf.reshape(inputs, (1, 1))

with nengo.Network() as net:
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ nengo_test_unsupported =
tests/test_probe.py::test_multirun
"simulation times may not line up exactly if unroll_simulation != 1, see
tests/test_nengo_tests.py::test_multirun"
tests/test_learning_rules.py::test_rls_*
"RLS learning rule not implemented"
allclose_tolerances =
tests/test_synapses.py::test_lowpass atol=5e-7
tests/test_synapses.py::test_triangle atol=5e-7
Expand Down