
Error when using BatchNormalization Layer in TensorNode #109

Closed
drasmuss opened this issue Nov 12, 2019 · 0 comments · Fixed by #163


Trying to optimize a model that contains BatchNormalization layers inside TensorNodes results in an error. E.g.,

import tensorflow as tf
import nengo
import nengo_dl
import numpy as np

with nengo.Network() as net:
    a = nengo.Node([0])
    # wrap a Keras BatchNormalization layer in a TensorNode
    b = nengo_dl.Layer(tf.keras.layers.BatchNormalization())(a)
    p = nengo.Probe(b)

with nengo_dl.Simulator(net) as sim:
    sim.compile(optimizer=tf.optimizers.SGD(0), loss=tf.losses.mse)
    # the error is raised during training
    sim.fit(np.ones((1, 1, 1)), np.ones((1, 1, 1)))

tensorflow.python.framework.errors_impl.InvalidArgumentError: {{node training_1/group_deps}} has inputs from different frames. The input {{node TensorGraph/while/iteration_0/SimTensorNodeBuilder/cond_2/Merge}} is in frame 'TensorGraph/while/while_context'. The input {{node loss/mul}} is in frame ''.

I'd guess that using BatchNormalization layers inside any TensorFlow while loop results in the same error, but I haven't looked into making a minimal example yet.
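For context, a hedged guess at the mechanism (not verified in this issue): BatchNormalization layers maintain stateful moving statistics that are updated as a side effect of each training-mode call, and those update ops may be what ends up referencing tensors across the while-loop frame boundary. A minimal sketch showing the stateful variables such a layer carries:

```python
import numpy as np
import tensorflow as tf

# BatchNormalization keeps non-trainable moving_mean/moving_variance
# variables that are updated as a side effect of each training-mode
# call; these stateful updates are a plausible source of the
# cross-frame inputs mentioned in the error message.
bn = tf.keras.layers.BatchNormalization()
x = np.ones((2, 3), dtype=np.float32)
bn(x, training=True)  # triggers a moving mean/variance update

# gamma/beta are trainable; the moving statistics are not
print([v.name for v in bn.trainable_variables])
print([v.name for v in bn.non_trainable_variables])
```

This only illustrates the layer's state, not the failure itself; reproducing the error with a bare tf.while_loop would still need a separate minimal example.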
