This popped up while trying to drop an LMU cell wrapped in a Keras RNN layer into a model. I've reproduced it here with the standard LSTM cell and MNIST data for convenience. This may be somewhat related to #122, as it uses that method of reshaping data.
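To make the failing setup concrete: "wrapping a cell in a Keras RNN layer" just means passing any object that implements the Keras RNN-cell interface (call, state_size, output_size) to tf.keras.layers.RNN, and then handing that layer to nengo_dl.Layer. Below is a minimal sketch of that pattern; MinimalRNNCell is a toy stand-in for the LMU cell (the names and parameters here are illustrative only, not the actual LMU implementation).

import numpy as np
import tensorflow as tf
import nengo
import nengo_dl

class MinimalRNNCell(tf.keras.layers.Layer):
    """Toy stand-in for an LMU cell; implements the Keras RNN-cell interface."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units
        self.output_size = units

    def build(self, input_shape):
        # input kernel and recurrent kernel for a plain tanh RNN
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units), initializer="glorot_uniform",
            name="kernel")
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units), initializer="orthogonal",
            name="recurrent_kernel")

    def call(self, inputs, states):
        h = tf.tanh(
            tf.matmul(inputs, self.kernel)
            + tf.matmul(states[0], self.recurrent_kernel))
        return h, [h]

with nengo.Network() as net:
    # flattened 28x28 input, reshaped back to (28, 28) inside the TensorNode
    inp = nengo.Node(np.zeros(28 * 28))
    h = nengo_dl.Layer(tf.keras.layers.RNN(MinimalRNNCell(units=128)))(
        inp, shape_in=(28, 28))
    out = nengo_dl.Layer(tf.keras.layers.Dense(units=10))(h)
    p = nengo.Probe(out)

The full repro below uses tf.keras.layers.LSTMCell in place of this toy cell.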
import nengo
from nengo.utils.filter_design import cont2discrete
import numpy as np
import tensorflow as tf
import nengo_dl
import os; os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # GPU

# set seed to ensure this example is reproducible
seed = 0
tf.random.set_seed(seed)
np.random.seed(seed)
rng = np.random.RandomState(seed)

# ### Data

# load mnist dataset
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.mnist.load_data())

# change inputs to 0--1 range
train_images = train_images / 255
test_images = test_images / 255

# reshape the labels to rank 3 (as expected in Nengo)
train_labels = train_labels[:, None, None]
test_labels = test_labels[:, None, None]

# ### Network

with nengo.Network(seed=seed) as net:
    nengo_dl.configure_settings(
        trainable=None, stateful=False, keep_history=False,
    )

    inp = nengo.Node(np.zeros(np.prod(train_images.shape[1:])))  # flattened inputs
    inp_dropout = nengo_dl.Layer(tf.keras.layers.Dropout(rate=0.1))(inp)
    h = nengo_dl.Layer(tf.keras.layers.RNN(tf.keras.layers.LSTMCell(units=128)))(
        inp_dropout, shape_in=train_images.shape[1:])
    out = nengo_dl.Layer(tf.keras.layers.Dense(units=10))(h)
    p = nengo.Probe(out)

# ### Training

trigger_error = True

with nengo_dl.Simulator(
        net, minibatch_size=100) as sim:
    sim.compile(
        loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.optimizers.Adam(),
        metrics=["accuracy"],
    )

    # This sim.evaluate seems to trigger the problem, works with this commented out?!
    # Successfully prints accuracy then attempts to begin training and fails
    if trigger_error:
        print(
            "Initial test accuracy: %.2f%%" % (sim.evaluate(
                {inp: test_images.reshape((test_images.shape[0], 1, -1))},
                {p: test_labels}, verbose=0)["probe_accuracy"] * 100
            )
        )

    sim.fit(
        {inp: train_images.reshape((train_images.shape[0], 1, -1))},
        {p: train_labels},
        epochs=1,
    )
It looks like the sim.evaluate call to get the initial test accuracy is the culprit.
If you don't do this evaluation (i.e., set trigger_error = False), then training works as expected.
Once the error has been triggered, it continues to fail even if you set trigger_error = False for subsequent attempts.
Some extra info:
This runs successfully on the CPU
I've also tried this on a couple of machines to rule out hardware issues
I've tried a few things not captured in the minimal example, none of which had an effect (sketched below):
tf.config.set_soft_device_placement(True)
Defining the network inside with tf.device("/gpu:0"):
Defining the Simulator with device="/gpu:0"
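For reference, a minimal self-contained sketch of those three attempts; the tiny Dense network here is just a placeholder for the real model, and none of these changed the outcome:

import numpy as np
import tensorflow as tf
import nengo
import nengo_dl

# 1) allow TF to fall back to CPU for ops that have no GPU kernel
tf.config.set_soft_device_placement(True)

# 2) build the network under an explicit GPU device scope
with tf.device("/gpu:0"):
    with nengo.Network() as net:
        inp = nengo.Node(np.zeros(4))
        out = nengo_dl.Layer(tf.keras.layers.Dense(units=2))(inp)
        p = nengo.Probe(out)

# 3) also pin the Simulator itself to the GPU
with nengo_dl.Simulator(net, minibatch_size=1, device="/gpu:0") as sim:
    sim.predict({inp: np.ones((1, 1, 4))})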
My environment is as follows:
Ubuntu 16.04
Miniconda 4.8.3
Python 3.8.3
tensorflow==2.2.0
tensorflow-gpu==2.2.0
nengo==3.0.0
nengo-dl==3.2.0
numpy==1.18.1
cudatoolkit 10.1.243
cudnn 7.6.5
The full output is very long, so I've attached it as a text file (output.txt) instead, but this is what I believe is the core error:
InvalidArgumentError: Cannot assign a device for operation TensorGraph/while/iteration_0/SimTensorNodeBuilder_1/while/lstm_cell/mul_2: Could not satisfy explicit device specification '' because the node {{colocation_node TensorGraph/while/iteration_0/SimTensorNodeBuilder_1/while/lstm_cell/mul_2}} was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/device:GPU:0'. All available devices [/job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0].
Colocation Debug Info:
Colocation group had the following types and supported devices:
Root Member(assigned_device_name_index_=1 requested_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' assigned_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' resource_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' supported_device_types_=[CPU] possible_devices_=[]
TensorArrayReadV3: GPU CPU XLA_CPU XLA_GPU
TensorArrayGradV3: GPU CPU XLA_CPU XLA_GPU
Exit: GPU CPU XLA_CPU XLA_GPU
TensorArrayWriteV3: GPU CPU XLA_CPU XLA_GPU
Enter: GPU CPU XLA_CPU XLA_GPU
Mul: GPU CPU XLA_CPU XLA_GPU
Const: GPU CPU XLA_CPU XLA_GPU
Range: GPU CPU XLA_CPU XLA_GPU
Identity: GPU CPU XLA_CPU XLA_GPU
TensorArrayGatherV3: GPU CPU XLA_CPU XLA_GPU
TensorArrayScatterV3: GPU CPU XLA_CPU XLA_GPU
StackPopV2: CPU
TensorArrayV3: GPU CPU XLA_CPU XLA_GPU
StackPushV2: CPU
StackV2: GPU CPU XLA_CPU XLA_GPU
In this minimal example there are a couple of items (StackPopV2 and StackPushV2) that show CPU only; that was not the case with the error in my original context.
This may be related, so I'm posting it here instead of opening a separate issue: in order for the GPU to be used for my model (which uses nengo only), I needed to change this line: