From 1ccfbeeae2fda6afbd942ea01f11635ef56914a7 Mon Sep 17 00:00:00 2001 From: Daniel Rasmussen Date: Mon, 2 Nov 2020 20:37:35 -0400 Subject: [PATCH] Support for TensorFlow 2.4 --- CHANGES.rst | 2 ++ nengo_dl/converter.py | 4 +--- nengo_dl/neuron_builders.py | 4 ++-- nengo_dl/simulator.py | 4 +++- nengo_dl/tensor_node.py | 2 +- nengo_dl/tests/test_simulator.py | 11 ++++++++++- nengo_dl/tests/test_tensor_node.py | 2 +- 7 files changed, 20 insertions(+), 9 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index ed22d7998..cfa4cff88 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -26,6 +26,7 @@ Release history - Added support for `KerasSpiking `_ layers in the Converter. (`#182`_) - Added support for ``tf.keras.layers.TimeDistributed`` in the Converter. (`#182`_) +- Added support for TensorFlow 2.4. (`#185`_) **Changed** @@ -47,6 +48,7 @@ Release history .. _#173: https://github.com/nengo/nengo-dl/pull/173 .. _#181: https://github.com/nengo/nengo-dl/pull/181 .. _#182: https://github.com/nengo/nengo-dl/pull/182 +.. _#185: https://github.com/nengo/nengo-dl/pull/185 3.3.0 (August 14, 2020) ----------------------- diff --git a/nengo_dl/converter.py b/nengo_dl/converter.py index a9c760698..eb520e035 100644 --- a/nengo_dl/converter.py +++ b/nengo_dl/converter.py @@ -391,9 +391,7 @@ def _get_key(self, key): # get output tensor key = key.output - key = tuple( - x.ref() if isinstance(x, tf.Tensor) else x for x in tf.nest.flatten(key) - ) + key = tuple(x.ref() if tf.is_tensor(x) else x for x in tf.nest.flatten(key)) return key diff --git a/nengo_dl/neuron_builders.py b/nengo_dl/neuron_builders.py index d58e77094..a653674d1 100644 --- a/nengo_dl/neuron_builders.py +++ b/nengo_dl/neuron_builders.py @@ -11,7 +11,7 @@ from nengo.neurons import RectifiedLinear, SpikingRectifiedLinear, Sigmoid, LIF, LIFRate import numpy as np import tensorflow as tf -from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.framework import smart_cond from nengo_dl import compat, utils from nengo_dl.builder import Builder, OpBuilder @@ -188,7 +188,7 @@ def build_step(self, signals, **step_kwargs): out = step_output else: out = tf.nest.flatten( - tf_utils.smart_cond( + smart_cond.smart_cond( self.config.training, true_fn=lambda: (self.training_step(J, signals.dt, **state),) + tuple(state.values()), diff --git a/nengo_dl/simulator.py b/nengo_dl/simulator.py index 170a1eede..3dfa2ac90 100644 --- a/nengo_dl/simulator.py +++ b/nengo_dl/simulator.py @@ -1815,7 +1815,9 @@ def _generate_inputs(self, data=None, n_steps=None): if data is None: data = {} - if not isinstance(data, (list, tuple, dict, np.ndarray, tf.Tensor)): + if not isinstance(data, (list, tuple, dict, np.ndarray)) and not tf.is_tensor( + data + ): # data is some kind of generator, so we don't try to modify it (too many # different types of generators this could be) if n_steps is not None: diff --git a/nengo_dl/tensor_node.py b/nengo_dl/tensor_node.py index 83e6b9609..16d96041f 100644 --- a/nengo_dl/tensor_node.py +++ b/nengo_dl/tensor_node.py @@ -41,7 +41,7 @@ def validate_output(output, minibatch_size=None, output_d=None, dtype=None): Expected dtype of the function output. """ - if not isinstance(output, (tf.Tensor, tf.TensorSpec)): + if not isinstance(output, tf.TensorSpec) and not tf.is_tensor(output): raise ValidationError( "TensorNode function must return a Tensor (got %s)" % type(output), attr="tensor_func", diff --git a/nengo_dl/tests/test_simulator.py b/nengo_dl/tests/test_simulator.py index f1c8ff103..360d04c5b 100644 --- a/nengo_dl/tests/test_simulator.py +++ b/nengo_dl/tests/test_simulator.py @@ -837,7 +837,16 @@ def test_tensorboard(Simulator, tmpdir): @pytest.mark.parametrize("mode", ("predict", "train")) @pytest.mark.training -def test_profile(Simulator, mode, tmpdir): +def test_profile(Simulator, mode, tmpdir, pytestconfig): + if ( + pytestconfig.getoption("--graph-mode") + and version.parse(tf.__version__) >= version.parse("2.4.0rc0") + and mode == "predict" + ): + pytest.skip( + "TensorFlow bug, see https://github.com/tensorflow/tensorflow/issues/44563" + ) + net, a, p = dummies.linear_net() with Simulator(net) as sim: diff --git a/nengo_dl/tests/test_tensor_node.py b/nengo_dl/tests/test_tensor_node.py index 32be3b808..13bc2f76e 100644 --- a/nengo_dl/tests/test_tensor_node.py +++ b/nengo_dl/tests/test_tensor_node.py @@ -373,7 +373,7 @@ def __init__(self, expected): self.expected = expected def call(self, inputs, training=None): - tf.assert_equal(training, self.expected) + tf.assert_equal(tf.cast(training, tf.bool), self.expected) return tf.reshape(inputs, (1, 1)) with nengo.Network() as net: