Skip to content

Commit

Permalink
Support for TensorFlow 2.4
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Nov 3, 2020
1 parent 78e30e8 commit 1ccfbee
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Release history
- Added support for `KerasSpiking <https://www.nengo.ai/keras-spiking/>`_ layers in
the Converter. (`#182`_)
- Added support for ``tf.keras.layers.TimeDistributed`` in the Converter. (`#182`_)
- Added support for TensorFlow 2.4. (`#185`_)

**Changed**

Expand All @@ -47,6 +48,7 @@ Release history
.. _#173: https://github.com/nengo/nengo-dl/pull/173
.. _#181: https://github.com/nengo/nengo-dl/pull/181
.. _#182: https://github.com/nengo/nengo-dl/pull/182
.. _#185: https://github.com/nengo/nengo-dl/pull/185

3.3.0 (August 14, 2020)
-----------------------
Expand Down
4 changes: 1 addition & 3 deletions nengo_dl/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,7 @@ def _get_key(self, key):
# get output tensor
key = key.output

key = tuple(
x.ref() if isinstance(x, tf.Tensor) else x for x in tf.nest.flatten(key)
)
key = tuple(x.ref() if tf.is_tensor(x) else x for x in tf.nest.flatten(key))

return key

Expand Down
4 changes: 2 additions & 2 deletions nengo_dl/neuron_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from nengo.neurons import RectifiedLinear, SpikingRectifiedLinear, Sigmoid, LIF, LIFRate
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.framework import smart_cond

from nengo_dl import compat, utils
from nengo_dl.builder import Builder, OpBuilder
Expand Down Expand Up @@ -188,7 +188,7 @@ def build_step(self, signals, **step_kwargs):
out = step_output
else:
out = tf.nest.flatten(
tf_utils.smart_cond(
smart_cond.smart_cond(
self.config.training,
true_fn=lambda: (self.training_step(J, signals.dt, **state),)
+ tuple(state.values()),
Expand Down
4 changes: 3 additions & 1 deletion nengo_dl/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,7 +1815,9 @@ def _generate_inputs(self, data=None, n_steps=None):
if data is None:
data = {}

if not isinstance(data, (list, tuple, dict, np.ndarray, tf.Tensor)):
if not isinstance(data, (list, tuple, dict, np.ndarray)) and not tf.is_tensor(
data
):
# data is some kind of generator, so we don't try to modify it (too many
# different types of generators this could be)
if n_steps is not None:
Expand Down
2 changes: 1 addition & 1 deletion nengo_dl/tensor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def validate_output(output, minibatch_size=None, output_d=None, dtype=None):
Expected dtype of the function output.
"""

if not isinstance(output, (tf.Tensor, tf.TensorSpec)):
if not isinstance(output, tf.TensorSpec) and not tf.is_tensor(output):
raise ValidationError(
"TensorNode function must return a Tensor (got %s)" % type(output),
attr="tensor_func",
Expand Down
11 changes: 10 additions & 1 deletion nengo_dl/tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,16 @@ def test_tensorboard(Simulator, tmpdir):

@pytest.mark.parametrize("mode", ("predict", "train"))
@pytest.mark.training
def test_profile(Simulator, mode, tmpdir):
def test_profile(Simulator, mode, tmpdir, pytestconfig):
if (
pytestconfig.getoption("--graph-mode")
and version.parse(tf.__version__) >= version.parse("2.4.0rc0")
and mode == "predict"
):
pytest.skip(
"TensorFlow bug, see https://github.com/tensorflow/tensorflow/issues/44563"
)

net, a, p = dummies.linear_net()

with Simulator(net) as sim:
Expand Down
2 changes: 1 addition & 1 deletion nengo_dl/tests/test_tensor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def __init__(self, expected):
self.expected = expected

def call(self, inputs, training=None):
tf.assert_equal(training, self.expected)
tf.assert_equal(tf.cast(training, tf.bool), self.expected)
return tf.reshape(inputs, (1, 1))

with nengo.Network() as net:
Expand Down

0 comments on commit 1ccfbee

Please sign in to comment.