diff --git a/CHANGES.rst b/CHANGES.rst index 2b3f9fbfe..e3625f9f7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -37,6 +37,9 @@ Release history - Added a new ``synapse`` argument to the Converter, which can be used to automatically add synaptic filters on the output of neural layers during the conversion process. (`#141`_) +- Added a `new example `__ + demonstrating how to use the NengoDL Converter to convert a Keras model to a spiking + Nengo network. (`#141`_) **Changed** diff --git a/README.rst b/README.rst index be3c6faf8..24b089a82 100644 --- a/README.rst +++ b/README.rst @@ -60,8 +60,9 @@ adds a number of unique features, such as: - optimizing the parameters of a model through deep learning training methods (using the Keras API) - faster simulation speed, on both CPU and GPU -- inserting networks defined using TensorFlow (such as - deep learning architectures) directly into a Nengo model +- automatic conversion from Keras models to Nengo networks +- inserting TensorFlow code (individual functions or whole + network architectures) directly into a Nengo model **Documentation** diff --git a/docs/examples.rst b/docs/examples.rst index 62c8de887..ff2fd0ee2 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -33,6 +33,7 @@ These examples illustrate some different possible use cases for NengoDL: examples/tensorflow-models examples/spiking-mnist + examples/keras-to-snn examples/lmu examples/spa-retrieval examples/spa-memory diff --git a/docs/examples/keras-to-snn.ipynb b/docs/examples/keras-to-snn.ipynb new file mode 100644 index 000000000..d022e35c9 --- /dev/null +++ b/docs/examples/keras-to-snn.ipynb @@ -0,0 +1,594 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Converting a Keras model to a spiking neural network\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nengo/nengo-dl/blob/master/docs/examples/keras-to-snn.ipynb)\n", + "\n", + "A key feature of NengoDL is the ability to convert non-spiking networks into spiking networks. We can build both spiking and non-spiking networks in NengoDL, but often we may have an existing non-spiking network defined in a framework like Keras that we want to convert to a spiking network. The [NengoDL Converter](https://www.nengo.ai/nengo-dl/converter.html) is designed to assist in that kind of translation. By default, the converter takes in a Keras model and outputs an exactly equivalent Nengo network (so the Nengo network will be non-spiking). However, the converter can also apply various transformations during this conversion process, in particular aimed at converting a non-spiking Keras model into a spiking Nengo model.\n", + "\n", + "The goal of this notebook is to familiarize you with the process of converting a Keras network to a spiking neural network. Swapping to spiking neurons is a significant change to a model, which will have far-reaching impacts on the model's behaviour; we cannot simply change the neuron type and expect the model to perform the same without making any other changes to the model. This example will walk through some steps to take to help tune a spiking model to more closely match the performance of the original non-spiking network." 
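+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before diving in, here is a minimal sketch of what the basic conversion workflow looks like, using a small toy model purely for illustration (the model we'll actually use in this example is built below). The Converter takes a Keras `Model` and exposes the converted Nengo network through its `net` attribute, along with mappings from the Keras input/output tensors to the corresponding Nengo objects. The spiking-oriented transformations we'll use later (`swap_activations`, `scale_firing_rates`, and `synapse`) are all optional arguments to this same call."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# minimal sketch of the conversion workflow, using a toy model for\n",
+    "# illustration; the model for the rest of this example is defined below\n",
+    "import tensorflow as tf\n",
+    "\n",
+    "import nengo_dl\n",
+    "\n",
+    "toy_inp = tf.keras.Input(shape=(4,))\n",
+    "toy_out = tf.keras.layers.Dense(units=2, activation=tf.nn.relu)(toy_inp)\n",
+    "toy_model = tf.keras.Model(inputs=toy_inp, outputs=toy_out)\n",
+    "\n",
+    "toy_converter = nengo_dl.Converter(toy_model)\n",
+    "\n",
+    "# the converted Nengo network, and the Nengo objects corresponding\n",
+    "# to the Keras input/output tensors\n",
+    "print(toy_converter.net)\n",
+    "print(toy_converter.inputs[toy_inp])\n",
+    "print(toy_converter.outputs[toy_out])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With that basic pattern in mind, let's work through the full example."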
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "from urllib.request import urlretrieve\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import nengo\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "\n", + "import nengo_dl\n", + "\n", + "seed = 0\n", + "np.random.seed(seed)\n", + "tf.random.set_seed(seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example we'll use the standard [MNIST dataset](http://yann.lecun.com/exdb/mnist/)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "(train_images, train_labels), (test_images, test_labels) = (\n", + " tf.keras.datasets.mnist.load_data())\n", + "\n", + "# flatten images and add time dimension\n", + "train_images = train_images.reshape((train_images.shape[0], 1, -1))\n", + "train_labels = train_labels.reshape((train_labels.shape[0], 1, -1))\n", + "test_images = test_images.reshape((test_images.shape[0], 1, -1))\n", + "test_labels = test_labels.reshape((test_labels.shape[0], 1, -1))\n", + "\n", + "plt.figure(figsize=(12, 4))\n", + "for i in range(3):\n", + " plt.subplot(1, 3, i+1)\n", + " plt.imshow(np.reshape(train_images[i], (28, 28)), cmap=\"gray\")\n", + " plt.axis('off')\n", + " plt.title(str(train_labels[i, 0, 0]));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting a Keras model to a Nengo network\n", + "\n", + "Next we'll build a simple convolutional network. This architecture is chosen to be a quick and easy solution for this task; other tasks would likely require a different architecture, but the same general principles will apply." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# input\n", + "inp = tf.keras.Input(shape=(28, 28, 1))\n", + "\n", + "# convolutional layers\n", + "conv0 = tf.keras.layers.Conv2D(\n", + " filters=32,\n", + " kernel_size=3,\n", + " activation=tf.nn.relu,\n", + ")(inp)\n", + "\n", + "conv1 = tf.keras.layers.Conv2D(\n", + " filters=64,\n", + " kernel_size=3,\n", + " strides=2,\n", + " activation=tf.nn.relu,\n", + ")(conv0)\n", + "\n", + "# fully connected layer\n", + "flatten = tf.keras.layers.Flatten()(conv1)\n", + "dense = tf.keras.layers.Dense(units=10)(flatten)\n", + "\n", + "model = tf.keras.Model(inputs=inp, outputs=dense)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the Keras model is created, we can pass it into the NengoDL Converter. The Converter tool is designed to automate the translation from Keras to Nengo as much as possible. You can see the full list of arguments the Converter accepts in the [documentation](https://www.nengo.ai/nengo-dl/reference.html?highlight=converter#nengo_dl.Converter)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "converter = nengo_dl.Converter(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we are ready to train the network. It's important to note that we are using standard (non-spiking) ReLU neurons at this point.\n", + "\n", + "To make this example run a bit more quickly we've provided some pre-trained weights that will be downloaded below; set `do_training=True` to run the training yourself." 
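+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If we want to confirm that the converted network really is non-spiking at this point, one quick check (just a sketch, not needed for the rest of the example) is to look at the neuron types of the ensembles in the converted network; they should all be the non-spiking `nengo.RectifiedLinear` type."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# sanity check: print the neuron types used in the converted network;\n",
+    "# at this stage they should all be non-spiking (nengo.RectifiedLinear)\n",
+    "for ensemble in converter.net.all_ensembles:\n",
+    "    print(ensemble, ensemble.neuron_type)"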
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "do_training = False\n", + "if do_training:\n", + " with nengo_dl.Simulator(converter.net, minibatch_size=200) as sim:\n", + " # run training\n", + " sim.compile(\n", + " optimizer=tf.optimizers.RMSprop(0.001),\n", + " loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " metrics=[tf.metrics.sparse_categorical_accuracy],\n", + " )\n", + " sim.fit(\n", + " {converter.inputs[inp]: train_images},\n", + " {converter.outputs[dense]: train_labels},\n", + " validation_data=(\n", + " {converter.inputs[inp]: test_images},\n", + " {converter.outputs[dense]: test_labels},\n", + " ),\n", + " epochs=2,\n", + " )\n", + "\n", + " # save the parameters to file\n", + " sim.save_params(\"./keras_to_snn_params\")\n", + "else:\n", + " # download pretrained weights\n", + " urlretrieve(\n", + " \"https://drive.google.com/uc?export=download&\"\n", + " \"id=1lBkR968AQo__t8sMMeDYGTQpBJZIs2_T\",\n", + " \"keras_to_snn_params.npz\")\n", + " print(\"Loaded pretrained weights\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After training for 2 epochs the non-spiking network is achieving ~98% accuracy on the test data, which is what we'd expect for a network this simple.\n", + "\n", + "Now that we have our trained weights, we can begin the conversion to spiking neurons. To help us in this process we're going to first define a helper function that will build the network for us, load weights from a specified file, and make it easy to play around with some other features of the network." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "def run_network(activation, params_file=\"keras_to_snn_params\", n_steps=30,\n", + " scale_firing_rates=1, synapse=None, n_test=400):\n", + " # convert the keras model to a nengo network\n", + " nengo_converter = nengo_dl.Converter(\n", + " model,\n", + " swap_activations={tf.nn.relu: activation},\n", + " scale_firing_rates=scale_firing_rates,\n", + " synapse=synapse,\n", + " )\n", + " \n", + " # get input/output objects\n", + " nengo_input = nengo_converter.inputs[inp]\n", + " nengo_output = nengo_converter.outputs[dense]\n", + " \n", + " # add a probe to the first convolutional layer to record activity\n", + " with nengo_converter.net:\n", + " conv0_probe = nengo.Probe(nengo_converter.layers[conv0])\n", + " \n", + " # repeat inputs for some number of timesteps\n", + " tiled_test_images = np.tile(test_images[:n_test], (1, n_steps, 1))\n", + " \n", + " # set some options to speed up simulation\n", + " with nengo_converter.net:\n", + " nengo_dl.configure_settings(stateful=False)\n", + " \n", + " # build network, load in trained weights, run inference on test images\n", + " with nengo_dl.Simulator(\n", + " nengo_converter.net, minibatch_size=10, \n", + " progress_bar=False) as nengo_sim:\n", + " nengo_sim.load_params(params_file)\n", + " data = nengo_sim.predict({nengo_input: tiled_test_images})\n", + " \n", + " # compute accuracy on test data, using output of network on\n", + " # last timestep\n", + " predictions = np.argmax(data[nengo_output][:, -1], axis=-1)\n", + " accuracy = (predictions == test_labels[:n_test, 0, 0]).mean()\n", + " print(\"Test accuracy: %.2f%%\" % (100 * accuracy))\n", + " \n", + " # plot the results\n", + " for ii in range(4):\n", + " plt.figure(figsize=(12, 4))\n", + " \n", + " plt.subplot(1, 3, 1)\n", + " 
plt.title(\"Input image\")\n", + " plt.imshow(test_images[ii, 0].reshape((28, 28)), cmap=\"gray\")\n", + " plt.axis('off')\n", + " \n", + " plt.subplot(1, 3, 2)\n", + " sample_neurons = np.linspace(\n", + " 0,\n", + " data[conv0_probe].shape[-1], \n", + " 1000,\n", + " endpoint=False,\n", + " dtype=np.int32,\n", + " )\n", + " scaled_data = data[conv0_probe][ii, :, sample_neurons].T * scale_firing_rates\n", + " if isinstance(activation, nengo.SpikingRectifiedLinear):\n", + " scaled_data *= 0.001\n", + " rates = np.sum(scaled_data, axis=0) / (n_steps * nengo_sim.dt)\n", + " plt.ylabel('Number of spikes')\n", + " else:\n", + " rates = scaled_data\n", + " plt.ylabel('Firing rates (Hz)')\n", + " plt.xlabel('Timestep')\n", + " plt.title(\n", + " \"Neural activities (conv0 mean=%dHz max=%dHz)\" % (\n", + " rates.mean(), rates.max())\n", + " )\n", + " plt.plot(scaled_data)\n", + " \n", + " plt.subplot(1, 3, 3)\n", + " plt.title(\"Output predictions\")\n", + " plt.plot(tf.nn.softmax(data[nengo_output][ii]))\n", + " plt.legend([str(j) for j in range(10)], loc=\"upper left\")\n", + " plt.xlabel('Timestep')\n", + " plt.ylabel(\"Probability\")\n", + " \n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now to run our trained network all we have to do is:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "run_network(activation=nengo.RectifiedLinear(), n_steps=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that we're plotting the output over time for consistency with future plots, but since our network doesn't have any temporal elements (e.g. spiking neurons), the output is constant for each digit." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting to a spiking neural network\n", + "\n", + "Now that we have the non-spiking version working in Nengo, we can start converting the network into spikes. Using the NengoDL converter, we can swap all the `relu` activation functions to `nengo.SpikingRectifiedLinear`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "run_network(activation=nengo.SpikingRectifiedLinear(), n_steps=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this naive conversion we are getting random accuracy (~10%), which indicates that the network is not functioning well. Next, we will look at various steps we can take to improve the performance of the spiking model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Presentation time\n", + "\n", + "If we look at the neural activity plots above, we can see one thing that's going wrong: the activities are all zero! (The non-zero final output is just a result of the internal biases). Referring back to the neural activity plot from our non-spiking network further up, we can gain a bit of insight into why this occurs. We can see that the firing rates are all below 100 Hz. 100 Hz means that a neuron is emitting approximately 1 spike every 10 timesteps (given the simulator timestep of 1ms). We're simulating for 10 time steps for each image, so we wouldn't really expect many of our neurons to be spiking within that 10 timestep window. If we present each image for longer we should start seeing some activity." 
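+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make that concrete, here is a rough back-of-the-envelope estimate (a sketch, not part of the original example): with the default 1ms simulator timestep, a neuron firing at `r` Hz emits roughly `r * n_steps * 0.001` spikes over a presentation window of `n_steps` timesteps."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# rough estimate of how many spikes to expect from a single neuron\n",
+    "# over the presentation window, assuming the default 1ms timestep\n",
+    "dt = 0.001  # seconds per timestep\n",
+    "\n",
+    "for rate in [25, 50, 100]:  # example firing rates (Hz)\n",
+    "    for n_steps in [10, 50]:\n",
+    "        print(\"rate=%3dHz, n_steps=%2d: ~%.1f spikes\" % (\n",
+    "            rate, n_steps, rate * n_steps * dt))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "So below we'll present each image for 50 timesteps, rather than 10, and see how the spiking network behaves."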
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "run_network(\n", + " activation=nengo.SpikingRectifiedLinear(),\n", + " n_steps=50,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see now that while initially there's no network activity, eventually we do start getting some spikes. Note that although we start seeing spikes in the `conv0` layer around the 10th timestep, we don't start seeing activity in the output layer until around the 40th timestep. That is because each layer in the network is adding a similar delay as we see in `conv0`, so when you put those all together in series it takes time for the activity to propagate through to the final output layer." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Synaptic smoothing\n", + "\n", + "Even with the increased presentation time, the test accuracy is still very low. This is because, as we can see in the output prediction plots, the network output is very noisy. Spikes are discrete events that exist for only a single time step and then disappear; we can see the literal \"spikes\" in the plots. Even if the neuron corresponding to the correct output is spiking quite rapidly, it's still not guaranteed that it will spike on exactly the last timestep (which is when we are checking the test accuracy).\n", + "\n", + "One way that we can compensate for this rapid fluctuation in the network output is to apply some smoothing to the spikes. This can be achieved in Nengo through the use of synaptic filters. The default `synapse` used in Nengo is a low-pass filter, and when we specify a value for the `synapse` parameter, that value is used as the low-pass filter time constant. When we pass a `synapse` value in the `run_network` function, it will create a low-pass filter with that time constant on the output of all the spiking neurons. \n", + "\n", + "Intuitively, we can think of this as computing a running average of each neuron's activity over a short window of time (rather than just looking at the spikes on the last timestep).\n", + "\n", + "Below we show results from the network running with three different low-pass filters. Note that adding synaptic filters will further increase the delay before neurons start spiking, because the filters will add their own \"ramp up\" time on each layer. So we'll run the network for even longer in these tests." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "for s in [0.001, 0.005, 0.01]:\n", + " print(\"Synapse=%.3f\" % s)\n", + " run_network(\n", + " activation=nengo.SpikingRectifiedLinear(),\n", + " n_steps=120,\n", + " synapse=s,\n", + " )\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that adding synaptic filtering smooths the output of the model and thereby improves the accuracy. With `synapse=0.01` we're achieving ~80% test accuracy; still not great, but significantly better than what we started with.\n", + "\n", + "However, increasing the magnitude of the synaptic filtering also increases the latency before we start seeing output activity. We can see that with `synapse=0.01` we don't start seeing output activity until around the 70th timestep. 
This means that with more synaptic filtering we have to present the input images for a longer period of time, which takes longer to simulate and adds more latency to the model's predictions. This is a common tradeoff in spiking networks (latency versus accuracy)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Firing rates\n", + "\n", + "Another way that we can improve network performance is by increasing the firing rates of the neurons. Neurons that spike more frequently update their output signal more often. This means that as firing rates increase, the behaviour of the spiking model will more closely match the original non-spiking model (where the neuron is directly outputting its true firing rate every timestep).\n", + "\n", + "#### Post-training scaling\n", + "\n", + "We can increase firing rates without retraining the model by applying a linear scale to the input of all the neurons (and then dividing their output by the same scale factor). Note that because we're applying a linear scale to the input and output, this will likely only work well with linear activation functions (like ReLU). To apply this scaling using the NengoDL Converter, we can use the `scale_firing_rates` parameter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "for scale in [2, 10, 20]:\n", + " print(\"Scale=%d\" % scale)\n", + " run_network(\n", + " activation=nengo.SpikingRectifiedLinear(),\n", + " scale_firing_rates=scale,\n", + " synapse=0.01\n", + " )\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that as the frequency of spiking increases, the accuracy also increases. And we're able to achieve good accuracy (very close to the original non-spiking network) without adding too much latency.\n", + "\n", + "Note that if we increase the firing rates enough, the spiking model eventually becomes equivalent to a non-spiking model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_network(\n", + " activation=nengo.SpikingRectifiedLinear(),\n", + " scale_firing_rates=1000,\n", + " n_steps=10\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "While this looks good from an accuracy perspective, it also means that we have lost many of the advantages of a spiking model (e.g. sparse communication, as indicated by the very high firing rates). This is another common tradeoff (accuracy versus firing rates) that can be customized depending on the demands of a particular application." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Regularizing during training\n", + "\n", + "Rather than using `scale_firing_rates` to upscale the firing rates after training, we can also directly optimize the firing rates during training. We'll add loss functions that compute the mean squared error (MSE) between the output activity of each of the convolutional layers and some target firing rates we specify. We can think of this as applying L2 regularization to the firing rates, but we've shifted the regularization point from 0 to some target value. One of the benefits of this method is that it is also effective for neurons with non-linear activation functions, such as LIF neurons." 
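+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a conceptual sketch (not the code we'll actually use below), the extra loss term for one layer is just the mean squared error between that layer's activities and a constant target rate, which is the same as L2 regularization with the regularization point shifted from zero to the target:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# conceptual sketch of the firing rate regularization term for one layer\n",
+    "# (the actual training code below applies tf.losses.mse to probed activities)\n",
+    "def rate_regularization(activities, target_rate):\n",
+    "    # activities: firing rates for one layer, shape (batch, time, neurons);\n",
+    "    # setting target_rate=0 would recover standard L2 regularization\n",
+    "    return np.mean((activities - target_rate) ** 2)\n",
+    "\n",
+    "# e.g. a layer with rates spread around 150Hz, regularized towards 250Hz\n",
+    "example_rates = np.random.uniform(0, 300, size=(1, 1, 100))\n",
+    "print(rate_regularization(example_rates, target_rate=250))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The training setup below implements this idea by adding probes on the two convolutional layers and passing constant arrays of target firing rates as their training targets."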
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# we'll encourage the neurons to spike at around 250Hz\n", + "target_rate = 250\n", + "\n", + "# convert keras model to nengo network\n", + "converter = nengo_dl.Converter(model)\n", + "\n", + "# add probes to the convolutional layers, which\n", + "# we'll use to apply the firing rate regularization\n", + "with converter.net:\n", + " output_p = converter.outputs[dense]\n", + " conv0_p = nengo.Probe(converter.layers[conv0])\n", + " conv1_p = nengo.Probe(converter.layers[conv1])\n", + "\n", + "with nengo_dl.Simulator(converter.net, minibatch_size=200) as sim:\n", + " # add regularization loss functions to the convolutional layers\n", + " sim.compile(\n", + " optimizer=tf.optimizers.RMSprop(0.001),\n", + " loss={\n", + " output_p: tf.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " conv0_p: tf.losses.mse,\n", + " conv1_p: tf.losses.mse,\n", + " },\n", + " loss_weights={output_p: 1, conv0_p: 1e-3, conv1_p: 1e-3}\n", + " )\n", + "\n", + " do_training = False\n", + " if do_training:\n", + " # run training (specifying the target rates for the convolutional layers)\n", + " sim.fit(\n", + " {converter.inputs[inp]: train_images}, \n", + " {\n", + " output_p: train_labels, \n", + " conv0_p: np.ones((train_labels.shape[0], 1, conv0_p.size_in)) \n", + " * target_rate,\n", + " conv1_p: np.ones((train_labels.shape[0], 1, conv1_p.size_in)) \n", + " * target_rate,\n", + " },\n", + " epochs=10)\n", + "\n", + " # save the parameters to file\n", + " sim.save_params(\"./keras_to_snn_regularized_params\")\n", + " else:\n", + " # download pretrained weights\n", + " urlretrieve(\n", + " \"https://drive.google.com/uc?export=download&\"\n", + " \"id=1xvIIIQjiA4UM9Mg_4rq_ttBH3wIl0lJx\",\n", + " \"keras_to_snn_regularized_params.npz\")\n", + " print(\"Loaded pretrained weights\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can examine the firing rates in the non-spiking network." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_network(\n", + " activation=nengo.RectifiedLinear(),\n", + " params_file=\"keras_to_snn_regularized_params\",\n", + " n_steps=10,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the neuron activity plot we can see that the firing rates are around the magnitude we specified (we could adjust the regularization function/weighting to refine this further). Now we can convert it to spiking neurons, without applying any scaling." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_network(\n", + " activation=nengo.SpikingRectifiedLinear(),\n", + " params_file=\"keras_to_snn_regularized_params\",\n", + " synapse=0.01,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that this network, because we trained it with spiking neurons in mind, can be converted to a spiking network without losing much performance or requiring any further tweaking." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusions\n", + "\n", + "In this example we've gone over the process of converting a non-spiking Keras model to a spiking Nengo network. We've shown some of the common issues that crop up, and how to go about diagnosing/addressing them. 
In particular, we looked at presentation time, synaptic filtering, and firing rates, and how adjusting those factors can affect various properties of the model (such as accuracy, latency, and temporal sparsity). Note that a lot of these factors represent tradeoffs that are application dependent. The particular parameters that we used in this example may not work or make sense in other applications, but this same workflow and thought process should apply to converting any kind of network to a spiking Nengo model." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}