From 08595b26ff5ac3f78f641e5aa5f511993db42162 Mon Sep 17 00:00:00 2001 From: markbookk Date: Mon, 1 Mar 2021 08:08:10 -0400 Subject: [PATCH 01/20] Adding section 8.5, not finished yet --- .../rnn-scratch.ipynb | 916 ++++++++++++++++++ utils/Functions.java | 7 + 2 files changed, 923 insertions(+) create mode 100644 chapter_recurrent-neural-networks/rnn-scratch.ipynb diff --git a/chapter_recurrent-neural-networks/rnn-scratch.ipynb b/chapter_recurrent-neural-networks/rnn-scratch.ipynb new file mode 100644 index 00000000..fef28fbe --- /dev/null +++ b/chapter_recurrent-neural-networks/rnn-scratch.ipynb @@ -0,0 +1,916 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 0 + }, + "source": [ + "# Implementation of Recurrent Neural Networks from Scratch\n", + ":label:`sec_rnn_scratch`\n", + "\n", + "In this section we will implement an RNN\n", + "from scratch\n", + "for a character-level language model,\n", + "according to our descriptions\n", + "in :numref:`sec_rnn`.\n", + "Such a model\n", + "will be trained on H. G. Wells' *The Time Machine*.\n", + "As before, we start by reading the dataset first, which is introduced in :numref:`sec_language_model`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", + "%maven org.slf4j:slf4j-api:1.7.26\n", + "%maven org.slf4j:slf4j-simple:1.7.26\n", + "%maven net.java.dev.jna:jna:5.6.0\n", + " \n", + "%maven ai.djl:api:0.11.0-SNAPSHOT\n", + "%maven ai.djl:basicdataset:0.11.0-SNAPSHOT\n", + "\n", + "// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md\n", + "// MXNet \n", + "%maven ai.djl.mxnet:mxnet-engine:0.11.0-SNAPSHOT\n", + "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport\n", + "\n", + "// Tensorflow\n", + "// %maven org.bytedeco:javacpp:1.5.4" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "// %%loadFromPOM\n", + "\n", + "// \n", + "// org.bytedeco\n", + "// javacv-platform\n", + "// 1.5.4\n", + "// \n", + "\n", + "// \n", + "// com.google.protobuf\n", + "// protobuf-java\n", + "// 3.8.0\n", + "// " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "// // MXNET\n", + "// // List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/ /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/ /Users/nivesmn/Documents/projects/djl/mxnet/mxnet-engine/build/libs/ /Users/nivesmn/Documents/projects/djl/mxnet/native/build/libs/\n", + "// List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/mxnet/mxnet-engine/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/mxnet/native/build/libs/*SNAPSHOT.jar\n", + "\n", + "// // TF\n", + "// // List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/ /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/ /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-engine/build/libs/ /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-native/build/libs/ /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-api/build/libs/ \n", + "// // List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-engine/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-native/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-api/build/libs/*SNAPSHOT.jar\n", + "\n", + "// // PT\n", + "// // List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/ /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/ /Users/nivesmn/Documents/projects/djl/./pytorch/pytorch-engine/build/libs/ /Users/nivesmn/Documents/projects/djl/./pytorch/pytorch-native/build/libs/\n", + " \n", + "// System.out.println(\"Size: \" + addedJars.size());\n", + "// addedJars" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ai.djl.ndarray.*;\n", + "import ai.djl.ndarray.NDList;\n", + "import ai.djl.ndarray.types.*;\n", + "import ai.djl.ndarray.index.*;\n", + "import ai.djl.util.Pair;\n", + "import ai.djl.Device;\n", + "import ai.djl.training.loss.*;\n", + "import ai.djl.training.*;\n", + "import ai.djl.engine.*;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load ../utils/plot-utils\n", + "%load ../utils/Functions.java\n", + "%load ../utils/PlotUtils.java\n", + "%load ../utils/TimeMachineUtils.java\n", + "%load ../utils/StopWatch.java\n", + "%load ../utils/Accumulator.java\n", + "%load ../utils/Animator.java\n", + "%load ../utils/Training.java" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@FunctionalInterface\n", + "public interface TriFunction {\n", + " public W apply(T t, U u, V v);\n", + "}\n", + "\n", + "@FunctionalInterface\n", + "public interface QuadFunction {\n", + " public R apply(T t, U u, V v, W w);\n", + "}\n", + "\n", + "@FunctionalInterface\n", + "public interface SimpleFunction {\n", + " public T apply();\n", + "}\n", + "\n", + "@FunctionalInterface\n", + "public interface voidFunction {\n", + " public void apply(T t);\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "NDManager manager = NDManager.newBaseManager();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "int batchSize = 32;\n", + "int numSteps = 35;\n", + "Pair, Vocab> timeMachine = loadDataTimeMachine(batchSize, numSteps, false, 10000);\n", + "List trainIter = timeMachine.getKey();\n", + "Vocab vocab = timeMachine.getValue();" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 6 + }, + "source": [ + "## One-Hot Encoding\n", + "\n", + "Recall that each token is represented as a numerical index in `train_iter`.\n", + "Feeding these indices directly to a neural network might make it hard to\n", + "learn.\n", + "We often represent each token as a more expressive feature vector.\n", + "The easiest representation is called *one-hot encoding*,\n", + "which is introduced\n", + "in :numref:`subsec_classification-problem`.\n", + "\n", + "In a nutshell, we map each index to a different unit vector: assume that the number of different tokens in the vocabulary is $N$ (`len(vocab)`) and the token indices range from 0 to $N-1$.\n", + "If the index of a token is the integer $i$, then we create a vector of all 0s with a length of $N$ and set the element at position $i$ to 1.\n", + "This vector is the one-hot vector of the original token. The one-hot vectors with indices 0 and 2 are shown below.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "manager.create(new int[] {0, 2}).oneHot(vocab.length())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 10 + }, + "source": [ + "The shape of the minibatch that we sample each time is (batch size, number of time steps).\n", + "The `one_hot` function transforms such a minibatch into a three-dimensional tensor with the last dimension equals to the vocabulary size (`len(vocab)`).\n", + "We often transpose the input so that we will obtain an\n", + "output of shape\n", + "(number of time steps, batch size, vocabulary size).\n", + "This will allow us\n", + "to more conveniently\n", + "loop through the outermost dimension\n", + "for updating hidden states of a minibatch,\n", + "time step by time step.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "NDArray X = manager.arange(10).reshape(new Shape(2,5));\n", + "X.transpose().oneHot(28).getShape()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 14 + }, + "source": [ + "## Initializing the Model Parameters\n", + "\n", + "Next, we initialize the model parameters for\n", + "the RNN model.\n", + "The number of hidden units `num_hiddens` is a tunable hyperparameter.\n", + "When training language models,\n", + "the inputs and outputs are from the same vocabulary.\n", + "Hence, they have the same dimension,\n", + "which is equal to the vocabulary size.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "public NDList getParams(int vocabSize, int numHiddens, Device device) {\n", + " int numOutputs = vocabSize;\n", + " int numInputs = vocabSize;\n", + " \n", + " // Hidden layer parameters\n", + " NDArray W_xh = normal(new Shape(numInputs, numHiddens), device);\n", + " NDArray W_hh = normal(new Shape(numHiddens, numHiddens), device);\n", + " NDArray b_h = manager.zeros(new Shape(numHiddens), DataType.FLOAT32, device);\n", + " // Output layer parameters\n", + " NDArray W_hq = normal(new Shape(numHiddens, numOutputs), device);\n", + " NDArray b_q = manager.zeros(new Shape(numOutputs), DataType.FLOAT32, device);\n", + " \n", + " \n", + " // Attach gradients\n", + " NDList params = new NDList(W_xh, W_hh, b_h, W_hq, b_q);\n", + " for (NDArray param : params) {\n", + " param.attachGradient();\n", + " }\n", + " return params;\n", + "}\n", + "public NDArray normal(Shape shape, Device device) {\n", + " return manager.randomNormal(0f, 0.01f, shape, DataType.FLOAT32, device);\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 18 + }, + "source": [ + "## RNN Model\n", + "\n", + "To define an RNN model,\n", + "we first need an `init_rnn_state` function\n", + "to return the hidden state at initialization.\n", + "It returns a tensor filled with 0 and with a shape of (batch size, number of hidden units).\n", + "Using tuples makes it easier to handle situations where the hidden state contains multiple variables,\n", + "which we will encounter in later sections.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "public NDArray initRNNState(int batchSize, int numHiddens, Device device) {\n", + " return manager.zeros(new Shape(batchSize, numHiddens), DataType.FLOAT32, device);\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 22 + }, + "source": [ + "The following `rnn` function defines how to compute the hidden state and output\n", + "at a time step.\n", + "Note that\n", + "the RNN model\n", + "loops through the outermost dimension of `inputs`\n", + "so that it updates hidden states `H` of a minibatch,\n", + "time step by time step.\n", + "Besides,\n", + "the activation function here uses the $\\tanh$ function.\n", + "As\n", + "described in :numref:`sec_mlp`, the\n", + "mean value of the $\\tanh$ function is 0, when the elements are uniformly\n", + "distributed over the real numbers.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "public Pair rnn(NDArray inputs, NDArray state, NDList params) {\n", + " // Shape of `inputs`: (`numSteps`, `batchSize`, `vocabSize`)\n", + " NDArray W_xh = params.get(0);\n", + " NDArray W_hh = params.get(1);\n", + " NDArray b_h = params.get(2);\n", + " NDArray W_hq = params.get(3);\n", + " NDArray b_q = params.get(4);\n", + " NDArray H = state;\n", + " \n", + " NDList outputs = new NDList();\n", + " //Shape of `X`: (`batchSize`, `vocabSize`)\n", + " NDArray X, Y;\n", + "// System.out.println(\"inputs.getShape().getShape()[0] -> \" + inputs.getShape().getShape()[0]);\n", + " for(int i = 0; i < inputs.getShape().getShape()[0]; i++) {\n", + " X = inputs.get(new NDIndex(i));\n", + " H = (X.dot(W_xh).add(H.dot(W_hh)).add(b_h)).tanh();\n", + " Y = H.dot(W_hq).add(b_q);\n", + " outputs.add(Y);\n", + "// System.out.println(Y.toDebugString(100, 100, 100, 100));\n", + " }\n", + " return new Pair(outputs.size() > 1 ? NDArrays.concat(outputs) : outputs.get(0), H);\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 26 + }, + "source": [ + "With all the needed functions being defined,\n", + "next we create a class to wrap these functions and store parameters for an RNN model implemented from scratch.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "/**\n", + " * An RNN Model implemented from scratch. \n", + "*/\n", + "class RNNModelScratch {\n", + " public int vocabSize;\n", + " public int numHiddens;\n", + " public NDList params;\n", + " public TriFunction initState;\n", + " public TriFunction forwardFn;\n", + " \n", + " \n", + " public RNNModelScratch(int vocabSize, int numHiddens, Device device, \n", + " TriFunction getParams, \n", + " TriFunction initRNNState,\n", + " TriFunction forwardFn\n", + " ) {\n", + " this.vocabSize = vocabSize;\n", + " this.numHiddens = numHiddens;\n", + " this.params = getParams.apply(vocabSize, numHiddens, device);\n", + " this.initState = initRNNState;\n", + " this.forwardFn = forwardFn;\n", + " }\n", + " \n", + " public Pair call(NDArray X, NDArray state) {\n", + " X = X.transpose().oneHot(this.vocabSize);\n", + " return this.forwardFn.apply(X, state, this.params);\n", + " }\n", + " \n", + " public NDArray beginState(int batchSize, Device device) {\n", + " return this.initState.apply(batchSize, this.numHiddens, device);\n", + " }\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 30 + }, + "source": [ + "Let us check whether the outputs have the correct shapes, e.g., to ensure that the dimensionality of the hidden state remains unchanged.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "int numHiddens = 512;\n", + "TriFunction getParamsFn = (a, b, c) -> getParams(a, b, c);\n", + "TriFunction initRNNStateFn = (a, b, c) -> initRNNState(a, b, c);\n", + "TriFunction rnnFn = (a, b, c) -> rnn(a, b, c);\n", + "\n", + "RNNModelScratch net = new RNNModelScratch(vocab.length(), numHiddens, Functions.tryGpu(0), \n", + " getParamsFn, initRNNStateFn, rnnFn);\n", + "NDArray state = net.beginState((int) X.getShape().getShape()[0], Functions.tryGpu(0));\n", + "Pair pairResult = net.call(X.toDevice(Functions.tryGpu(0), false), state);\n", + "NDArray Y = pairResult.getKey();\n", + "NDArray newState = pairResult.getValue();\n", + "System.out.println(Y.getShape());\n", + "System.out.println(newState.getShape());" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 34 + }, + "source": [ + "We can see that the output shape is (number of time steps $\\times$ batch size, vocabulary size), while the hidden state shape remains the same, i.e., (batch size, number of hidden units).\n", + "\n", + "\n", + "## Prediction\n", + "\n", + "Let us first define the prediction function\n", + "to generate new characters following\n", + "the user-provided `prefix`,\n", + "which is a string containing several characters.\n", + "When looping through these beginning characters in `prefix`,\n", + "we keep passing the hidden state\n", + "to the next time step without\n", + "generating any output.\n", + "This is called the *warm-up* period,\n", + "during which the model updates itself\n", + "(e.g., update the hidden state)\n", + "but does not make predictions.\n", + "After the warm-up period,\n", + "the hidden state is generally better than\n", + "its initialized value at the beginning.\n", + "So we generate the predicted characters and emit them.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "/**\n", + " * Generate new characters following the `prefix`.\"\"\"\n", + " */\n", + "public String predictCh8(String prefix, int numPreds, \n", + " RNNModelScratch net,\n", + " Vocab vocab, Device device) {\n", + " NDArray state = net.beginState(1, device);\n", + " List outputs = new ArrayList<>();\n", + " outputs.add(vocab.getIdx(\"\" + prefix.charAt(0)));\n", + " SimpleFunction getInput = () -> manager.create(outputs.get(outputs.size()-1))\n", + " .toDevice(device, false).reshape(new Shape(1, 1));\n", + " for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n", + " state = (NDArray) net.call(getInput.apply(), state).getValue();\n", + " outputs.add(vocab.getIdx(\"\" + c));\n", + " }\n", + " \n", + " NDArray y;\n", + " for (int i = 0; i < numPreds; i++) {\n", + " Pair pair = net.call(getInput.apply(), state);\n", + " y = pair.getKey();\n", + " state = pair.getValue();\n", + "\n", + " outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));\n", + " }\n", + " String outputString = \"\";\n", + " for (int i : outputs) {\n", + " outputString += vocab.idxToToken.get(i);\n", + " }\n", + " return outputString;\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 38 + }, + "source": [ + "Now we can test the `predict_ch8` function.\n", + "We specify the prefix as `time traveller ` and have it generate 10 additional characters.\n", + "Given that we have not trained the network,\n", + "it will generate nonsensical predictions.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "predictCh8(\"time traveller \", 10, net, vocab, Functions.tryGpu(0));" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 41 + }, + "source": [ + "## Gradient Clipping\n", + "\n", + "For a sequence of length $T$,\n", + "we compute the gradients over these $T$ time steps in an iteration, which results in a chain of matrix-products with length $\\mathcal{O}(T)$ during backpropagation.\n", + "As mentioned in :numref:`sec_numerical_stability`, it might result in numerical instability, e.g., the gradients may either explode or vanish, when $T$ is large. Therefore, RNN models often need extra help to stabilize the training.\n", + "\n", + "Generally speaking,\n", + "when solving an optimization problem,\n", + "we take update steps for the model parameter,\n", + "say in the vector form\n", + "$\\mathbf{x}$,\n", + "in the direction of the negative gradient $\\mathbf{g}$ on a minibatch.\n", + "For example,\n", + "with $\\eta > 0$ as the learning rate,\n", + "in one iteration we update\n", + "$\\mathbf{x}$\n", + "as $\\mathbf{x} - \\eta \\mathbf{g}$.\n", + "Let us further assume that the objective function $f$\n", + "is well behaved, say, *Lipschitz continuous* with constant $L$.\n", + "That is to say,\n", + "for any $\\mathbf{x}$ and $\\mathbf{y}$ we have\n", + "\n", + "$$|f(\\mathbf{x}) - f(\\mathbf{y})| \\leq L \\|\\mathbf{x} - \\mathbf{y}\\|.$$\n", + "\n", + "In this case we can safely assume that if we update the parameter vector by $\\eta \\mathbf{g}$, then\n", + "\n", + "$$|f(\\mathbf{x}) - f(\\mathbf{x} - \\eta\\mathbf{g})| \\leq L \\eta\\|\\mathbf{g}\\|,$$\n", + "\n", + "which means that\n", + "we will not observe a change by more than $L \\eta \\|\\mathbf{g}\\|$. This is both a curse and a blessing.\n", + "On the curse side,\n", + "it limits the speed of making progress;\n", + "whereas on the blessing side,\n", + "it limits the extent to which things can go wrong if we move in the wrong direction.\n", + "\n", + "Sometimes the gradients can be quite large and the optimization algorithm may fail to converge. We could address this by reducing the learning rate $\\eta$. But what if we only *rarely* get large gradients? In this case such an approach may appear entirely unwarranted. One popular alternative is to clip the gradient $\\mathbf{g}$ by projecting them back to a ball of a given radius, say $\\theta$ via\n", + "\n", + "$$\\mathbf{g} \\leftarrow \\min\\left(1, \\frac{\\theta}{\\|\\mathbf{g}\\|}\\right) \\mathbf{g}.$$\n", + "\n", + "By doing so we know that the gradient norm never exceeds $\\theta$ and that the\n", + "updated gradient is entirely aligned with the original direction of $\\mathbf{g}$.\n", + "It also has the desirable side-effect of limiting the influence any given\n", + "minibatch (and within it any given sample) can exert on the parameter vector. This\n", + "bestows a certain degree of robustness to the model. Gradient clipping provides\n", + "a quick fix to the gradient exploding. While it does not entirely solve the problem, it is one of the many techniques to alleviate it.\n", + "\n", + "Below we define a function to clip the gradients of\n", + "a model that is implemented from scratch or a model constructed by the high-level APIs.\n", + "Also note that we compute the gradient norm over all the model parameters.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "/**\n", + " * Clip the gradient.\n", + " */\n", + "public void gradClipping(RNNModelScratch net, int theta) {\n", + " double result = 0;\n", + " for (NDArray p : net.params) {\n", + " results += (double) p.getGradient().pow(2).sum().getFloat();\n", + " }\n", + " double norm = Math.sqrt(results);\n", + " if (norm > theta) {\n", + " for (NDArray param : net.params) {\n", + " NDArray gradient = param.getGradient();\n", + " gradient.set(new NDIndex(\":\"), theta / norm);\n", + " }\n", + " }\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 45 + }, + "source": [ + "## Training\n", + "\n", + "Before training the model,\n", + "let us define a function to train the model in one epoch. It differs from how we train the model of :numref:`sec_softmax_scratch` in three places:\n", + "\n", + "1. Different sampling methods for sequential data (random sampling and sequential partitioning) will result in differences in the initialization of hidden states.\n", + "1. We clip the gradients before updating the model parameters. This ensures that the model does not diverge even when gradients blow up at some point during the training process.\n", + "1. We use perplexity to evaluate the model. As discussed in :numref:`subsec_perplexity`, this ensures that sequences of different length are comparable.\n", + "\n", + "\n", + "Specifically,\n", + "when sequential partitioning is used, we initialize the hidden state only at the beginning of each epoch.\n", + "Since the $i^\\mathrm{th}$ subsequence example in the next minibatch is adjacent to the current $i^\\mathrm{th}$ subsequence example,\n", + "the hidden state at the end of the current minibatch\n", + "will be\n", + "used to initialize\n", + "the hidden state at the beginning of the next minibatch.\n", + "In this way,\n", + "historical information of the sequence\n", + "stored in the hidden state\n", + "might flow over\n", + "adjacent subsequences within an epoch.\n", + "However, the computation of the hidden state\n", + "at any point depends on all the previous minibatches\n", + "in the same epoch,\n", + "which complicates the gradient computation.\n", + "To reduce computational cost,\n", + "we detach the gradient before processing any minibatch\n", + "so that the gradient computation of the hidden state\n", + "is always limited to\n", + "the time steps in one minibatch. \n", + "\n", + "When using the random sampling,\n", + "we need to re-initialize the hidden state for each iteration since each example is sampled with a random position.\n", + "Same as the `train_epoch_ch3` function in :numref:`sec_softmax_scratch`,\n", + "`updater` is a general function\n", + "to update the model parameters.\n", + "It can be either the `d2l.sgd` function implemented from scratch or the built-in optimization function in\n", + "a deep learning framework.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "/**\n", + " * Train a model within one epoch.\n", + " */\n", + "public Pair trainEpochCh8(RNNModelScratch net, List trainIter, Loss loss,\n", + " voidFunction updater, \n", + " Device device, boolean useRandomIter) {\n", + " try {\n", + " NDArray state = null;\n", + " StopWatch watch = new StopWatch();\n", + " watch.start();\n", + " Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens\n", + " for (NDList pair : trainIter) {\n", + " NDArray X = pair.get(0);\n", + " NDArray Y = pair.get(1);\n", + " if (state == null || useRandomIter) {\n", + " // Initialize `state` when either it is the first iteration or\n", + " // using random sampling\n", + " state = net.beginState((int) X.getShape().getShape()[0], device);\n", + " } else {\n", + " state.detach();\n", + " }\n", + " NDArray y = Y.transpose().reshape(new Shape(-1));\n", + " X = X.toDevice(device, false);\n", + " y = y.toDevice(device, false);\n", + " try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n", + " Pair pairResult = net.call(X, state);\n", + " NDArray yHat = pairResult.getValue();\n", + " state = pairResult.getValue();\n", + " NDArray l = loss.evaluate(new NDList(yHat), new NDList(y)).mean();\n", + " gc.backward(l);\n", + " gradClipping(net, 1);\n", + " updater.apply(1); // Since the `mean` function has been invoked\n", + " metric.add(new float[] {1 * y.size(), y.size()});\n", + " }\n", + " }\n", + " }catch (Exception e) {e.printStackTrace();}\n", + " return new Pair(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "origin_pos": 42, + "tab": [ + "mxnet" + ] + }, + "outputs": [], + "source": [ + "def grad_clipping(net, theta): #@save\n", + " \"\"\"Clip the gradient.\"\"\"\n", + " if isinstance(net, gluon.Block):\n", + " params = [p.data() for p in net.collect_params().values()]\n", + " else:\n", + " params = net.params\n", + " norm = math.sqrt(sum((p.grad ** 2).sum() for p in params))\n", + " if norm > theta:\n", + " for param in params:\n", + " param.grad[:] *= theta / norm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 49 + }, + "source": [ + "The training function supports\n", + "an RNN model implemented\n", + "either from scratch\n", + "or using high-level APIs.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "/*\n", + " ** Train a model.\n", + " */\n", + "public void trainCh8(RNNModelScratch net, List trainIter, Vocab vocab,\n", + " int lr, int numEpochs, Device device, boolean useRandomIter) {\n", + " try {\n", + " SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();\n", + " Animator animator = new Animator();\n", + " // Initialize\n", + " voidFunction updater = (batchSize) -> Training.sgd(net.params, lr, batchSize);\n", + " Function predict = (prefix) -> \n", + " predictCh8(prefix, 50, net, vocab, device);\n", + " // Train and predict\n", + " double ppl = 0.0;\n", + " double speed = 0.0;\n", + " for (int epoch = 0; epoch < numEpochs; epoch++) {\n", + " Pair pair = \n", + " trainEpochCh8(net, trainIter, loss, updater, device, useRandomIter);\n", + " ppl = pair.getKey();\n", + " speed = pair.getValue();\n", + " if ((epoch + 1) % 10 == 0) {\n", + " animator.add(epoch + 1, (float) ppl, \"\");\n", + " }\n", + " }\n", + " System.out.format(\"perplexity: %.1d, %.1d tokens/sec on %s%n\", \n", + " ppl, speed, device.toString());\n", + " System.out.println(predict.apply(\"time traveller\"));\n", + " System.out.println(predict.apply(\"traveller\"));\n", + " }catch (Exception e) {e.printStackTrace();}\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 53 + }, + "source": [ + "Now we can train the RNN model.\n", + "Since we only use 10000 tokens in the dataset, the model needs more epochs to converge better.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "int numEpochs = 500;\n", + "int lr = 1;\n", + "trainCh8(net, trainIter, vocab, lr, numEpochs, Functions.tryGpu(0), false);" + ] + }, + { + "cell_type": "raw", + "metadata": { + "origin_pos": 54, + "tab": [ + "mxnet" + ] + }, + "source": [ + "num_epochs, lr = 500, 1\n", + "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 56 + }, + "source": [ + "Finally,\n", + "let us check the results of using the random sampling method.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainCh8(net, trainIter, vocab, lr, numEpochs, Functions.tryGpu(0), true);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 59 + }, + "source": [ + "While implementing the above RNN model from scratch is instructive, it is not convenient.\n", + "In the next section we will see how to improve the RNN model,\n", + "such as how to make it easier to implement\n", + "and make it run faster.\n", + "\n", + "\n", + "## Summary\n", + "\n", + "* We can train an RNN-based character-level language model to generate text following the user-provided text prefix.\n", + "* A simple RNN language model consists of input encoding, RNN modeling, and output generation.\n", + "* RNN models need state initialization for training, though random sampling and sequential partitioning use different ways.\n", + "* When using sequential partitioning, we need to detach the gradient to reduce computational cost.\n", + "* A warm-up period allows a model to update itself (e.g., obtain a better hidden state than its initialized value) before making any prediction.\n", + "* Gradient clipping prevents gradient explosion, but it cannot fix vanishing gradients.\n", + "\n", + "\n", + "## Exercises\n", + "\n", + "1. Show that one-hot encoding is equivalent to picking a different embedding for each object.\n", + "1. Adjust the hyperparameters (e.g., number of epochs, number of hidden units, number of time steps in a minibatch, and learning rate) to improve the perplexity.\n", + " * How low can you go?\n", + " * Replace one-hot encoding with learnable embeddings. Does this lead to better performance?\n", + " * How well will it work on other books by H. G. Wells, e.g., [*The War of the Worlds*](http://www.gutenberg.org/ebooks/36)?\n", + "1. Modify the prediction function such as to use sampling rather than picking the most likely next character.\n", + " * What happens?\n", + " * Bias the model towards more likely outputs, e.g., by sampling from $q(x_t \\mid x_{t-1}, \\ldots, x_1) \\propto P(x_t \\mid x_{t-1}, \\ldots, x_1)^\\alpha$ for $\\alpha > 1$.\n", + "1. Run the code in this section without clipping the gradient. What happens?\n", + "1. Change sequential partitioning so that it does not separate hidden states from the computational graph. Does the running time change? How about the perplexity?\n", + "1. Replace the activation function used in this section with ReLU and repeat the experiments in this section. Do we still need gradient clipping? Why?\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "origin_pos": 54, + "tab": [ + "mxnet" + ] + }, + "outputs": [], + "source": [ + "num_epochs, lr = 500, 1\n", + "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "origin_pos": 57, + "tab": [ + "mxnet" + ] + }, + "outputs": [], + "source": [ + "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(),\n", + " use_random_iter=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Java", + "language": "java", + "name": "java" + }, + "language_info": { + "codemirror_mode": "java", + "file_extension": ".jshell", + "mimetype": "text/x-java-source", + "name": "Java", + "pygments_lexer": "java", + "version": "11.0.10+9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/utils/Functions.java b/utils/Functions.java index 4a74df50..0b848e52 100644 --- a/utils/Functions.java +++ b/utils/Functions.java @@ -20,4 +20,11 @@ public static double[] floatToDoubleArray(float[] x) { } return ret; } + + /** + * Return the i'th GPU if it exists, otherwise return the CPU + */ + public static Device tryGpu(int i) { + return Device.getGpuCount() >= i + 1 ? Device.gpu(i) : Device.cpu(); + } } \ No newline at end of file From bac7d3b824a3a6eb1781283d3a93b7ad143fe8ca Mon Sep 17 00:00:00 2001 From: markbookk Date: Mon, 1 Mar 2021 15:47:28 -0400 Subject: [PATCH 02/20] Finishing section 8.5, still need to test --- .../rnn-scratch.ipynb | 262 +++++------------- 1 file changed, 76 insertions(+), 186 deletions(-) diff --git a/chapter_recurrent-neural-networks/rnn-scratch.ipynb b/chapter_recurrent-neural-networks/rnn-scratch.ipynb index fef28fbe..f521e14b 100644 --- a/chapter_recurrent-neural-networks/rnn-scratch.ipynb +++ b/chapter_recurrent-neural-networks/rnn-scratch.ipynb @@ -36,54 +36,7 @@ "// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md\n", "// MXNet \n", "%maven ai.djl.mxnet:mxnet-engine:0.11.0-SNAPSHOT\n", - "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport\n", - "\n", - "// Tensorflow\n", - "// %maven org.bytedeco:javacpp:1.5.4" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %%loadFromPOM\n", - "\n", - "// \n", - "// org.bytedeco\n", - "// javacv-platform\n", - "// 1.5.4\n", - "// \n", - "\n", - "// \n", - "// com.google.protobuf\n", - "// protobuf-java\n", - "// 3.8.0\n", - "// " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "// // MXNET\n", - "// // List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/ /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/ /Users/nivesmn/Documents/projects/djl/mxnet/mxnet-engine/build/libs/ /Users/nivesmn/Documents/projects/djl/mxnet/native/build/libs/\n", - "// List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/mxnet/mxnet-engine/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/mxnet/native/build/libs/*SNAPSHOT.jar\n", - "\n", - "// // TF\n", - "// // List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/ /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/ /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-engine/build/libs/ /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-native/build/libs/ /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-api/build/libs/ \n", - "// // List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-engine/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-native/build/libs/*SNAPSHOT.jar /Users/nivesmn/Documents/projects/djl/tensorflow/tensorflow-api/build/libs/*SNAPSHOT.jar\n", - "\n", - "// // PT\n", - "// // List addedJars = %jars /Users/nivesmn/Documents/projects/djl/api/build/libs/ /Users/nivesmn/Documents/projects/djl/basicdataset/build/libs/ /Users/nivesmn/Documents/projects/djl/./pytorch/pytorch-engine/build/libs/ /Users/nivesmn/Documents/projects/djl/./pytorch/pytorch-native/build/libs/\n", - " \n", - "// System.out.println(\"Size: \" + addedJars.size());\n", - "// addedJars" + "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport" ] }, { @@ -253,10 +206,10 @@ "metadata": {}, "outputs": [], "source": [ - "public NDList getParams(int vocabSize, int numHiddens, Device device) {\n", + "public static NDList getParams(int vocabSize, int numHiddens, Device device) {\n", " int numOutputs = vocabSize;\n", " int numInputs = vocabSize;\n", - " \n", + "\n", " // Hidden layer parameters\n", " NDArray W_xh = normal(new Shape(numInputs, numHiddens), device);\n", " NDArray W_hh = normal(new Shape(numHiddens, numHiddens), device);\n", @@ -264,8 +217,7 @@ " // Output layer parameters\n", " NDArray W_hq = normal(new Shape(numHiddens, numOutputs), device);\n", " NDArray b_q = manager.zeros(new Shape(numOutputs), DataType.FLOAT32, device);\n", - " \n", - " \n", + "\n", " // Attach gradients\n", " NDList params = new NDList(W_xh, W_hh, b_h, W_hq, b_q);\n", " for (NDArray param : params) {\n", @@ -273,7 +225,8 @@ " }\n", " return params;\n", "}\n", - "public NDArray normal(Shape shape, Device device) {\n", + "\n", + "public static NDArray normal(Shape shape, Device device) {\n", " return manager.randomNormal(0f, 0.01f, shape, DataType.FLOAT32, device);\n", "}" ] @@ -300,7 +253,7 @@ "metadata": {}, "outputs": [], "source": [ - "public NDArray initRNNState(int batchSize, int numHiddens, Device device) {\n", + "public static NDArray initRNNState(int batchSize, int numHiddens, Device device) {\n", " return manager.zeros(new Shape(batchSize, numHiddens), DataType.FLOAT32, device);\n", "}" ] @@ -332,7 +285,7 @@ "metadata": {}, "outputs": [], "source": [ - "public Pair rnn(NDArray inputs, NDArray state, NDList params) {\n", + "public static Pair rnn(NDArray inputs, NDArray state, NDList params) {\n", " // Shape of `inputs`: (`numSteps`, `batchSize`, `vocabSize`)\n", " NDArray W_xh = params.get(0);\n", " NDArray W_hh = params.get(1);\n", @@ -340,17 +293,15 @@ " NDArray W_hq = params.get(3);\n", " NDArray b_q = params.get(4);\n", " NDArray H = state;\n", - " \n", + "\n", " NDList outputs = new NDList();\n", - " //Shape of `X`: (`batchSize`, `vocabSize`)\n", + " // Shape of `X`: (`batchSize`, `vocabSize`)\n", " NDArray X, Y;\n", - "// System.out.println(\"inputs.getShape().getShape()[0] -> \" + inputs.getShape().getShape()[0]);\n", - " for(int i = 0; i < inputs.getShape().getShape()[0]; i++) {\n", + " for (int i = 0; i < inputs.getShape().getShape()[0]; i++) {\n", " X = inputs.get(new NDIndex(i));\n", " H = (X.dot(W_xh).add(H.dot(W_hh)).add(b_h)).tanh();\n", " Y = H.dot(W_hq).add(b_q);\n", " outputs.add(Y);\n", - "// System.out.println(Y.toDebugString(100, 100, 100, 100));\n", " }\n", " return new Pair(outputs.size() > 1 ? NDArrays.concat(outputs) : outputs.get(0), H);\n", "}" @@ -372,34 +323,33 @@ "metadata": {}, "outputs": [], "source": [ - "/**\n", - " * An RNN Model implemented from scratch. \n", - "*/\n", - "class RNNModelScratch {\n", + "/** An RNN Model implemented from scratch. */\n", + "public class RNNModelScratch {\n", " public int vocabSize;\n", " public int numHiddens;\n", " public NDList params;\n", " public TriFunction initState;\n", " public TriFunction forwardFn;\n", - " \n", - " \n", - " public RNNModelScratch(int vocabSize, int numHiddens, Device device, \n", - " TriFunction getParams, \n", - " TriFunction initRNNState,\n", - " TriFunction forwardFn\n", - " ) {\n", + "\n", + " public RNNModelScratch(\n", + " int vocabSize,\n", + " int numHiddens,\n", + " Device device,\n", + " TriFunction getParams,\n", + " TriFunction initRNNState,\n", + " TriFunction forwardFn) {\n", " this.vocabSize = vocabSize;\n", " this.numHiddens = numHiddens;\n", " this.params = getParams.apply(vocabSize, numHiddens, device);\n", " this.initState = initRNNState;\n", " this.forwardFn = forwardFn;\n", " }\n", - " \n", + "\n", " public Pair call(NDArray X, NDArray state) {\n", " X = X.transpose().oneHot(this.vocabSize);\n", " return this.forwardFn.apply(X, state, this.params);\n", " }\n", - " \n", + "\n", " public NDArray beginState(int batchSize, Device device) {\n", " return this.initState.apply(batchSize, this.numHiddens, device);\n", " }\n", @@ -419,17 +369,21 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [], "source": [ "int numHiddens = 512;\n", "TriFunction getParamsFn = (a, b, c) -> getParams(a, b, c);\n", - "TriFunction initRNNStateFn = (a, b, c) -> initRNNState(a, b, c);\n", + "TriFunction initRNNStateFn =\n", + " (a, b, c) -> initRNNState(a, b, c);\n", "TriFunction rnnFn = (a, b, c) -> rnn(a, b, c);\n", "\n", - "RNNModelScratch net = new RNNModelScratch(vocab.length(), numHiddens, Functions.tryGpu(0), \n", - " getParamsFn, initRNNStateFn, rnnFn);\n", + "NDArray X = manager.arange(10).reshape(new Shape(2, 5));\n", + "\n", + "RNNModelScratch net =\n", + " new RNNModelScratch(\n", + " vocab.length(), numHiddens, Functions.tryGpu(0), getParamsFn, initRNNStateFn, rnnFn);\n", "NDArray state = net.beginState((int) X.getShape().getShape()[0], Functions.tryGpu(0));\n", "Pair pairResult = net.call(X.toDevice(Functions.tryGpu(0), false), state);\n", "NDArray Y = pairResult.getKey();\n", @@ -473,22 +427,22 @@ "metadata": {}, "outputs": [], "source": [ - "/**\n", - " * Generate new characters following the `prefix`.\"\"\"\n", - " */\n", - "public String predictCh8(String prefix, int numPreds, \n", - " RNNModelScratch net,\n", - " Vocab vocab, Device device) {\n", + "/** Generate new characters following the `prefix`.\"\"\" */\n", + "public static String predictCh8(\n", + " String prefix, int numPreds, RNNModelScratch net, Vocab vocab, Device device) {\n", " NDArray state = net.beginState(1, device);\n", " List outputs = new ArrayList<>();\n", " outputs.add(vocab.getIdx(\"\" + prefix.charAt(0)));\n", - " SimpleFunction getInput = () -> manager.create(outputs.get(outputs.size()-1))\n", - " .toDevice(device, false).reshape(new Shape(1, 1));\n", + " SimpleFunction getInput =\n", + " () ->\n", + " manager.create(outputs.get(outputs.size() - 1))\n", + " .toDevice(device, false)\n", + " .reshape(new Shape(1, 1));\n", " for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n", " state = (NDArray) net.call(getInput.apply(), state).getValue();\n", " outputs.add(vocab.getIdx(\"\" + c));\n", " }\n", - " \n", + "\n", " NDArray y;\n", " for (int i = 0; i < numPreds; i++) {\n", " Pair pair = net.call(getInput.apply(), state);\n", @@ -591,15 +545,13 @@ "metadata": {}, "outputs": [], "source": [ - "/**\n", - " * Clip the gradient.\n", - " */\n", - "public void gradClipping(RNNModelScratch net, int theta) {\n", + "/** Clip the gradient. */\n", + "public static void gradClipping(RNNModelScratch net, int theta) {\n", " double result = 0;\n", " for (NDArray p : net.params) {\n", - " results += (double) p.getGradient().pow(2).sum().getFloat();\n", + " result += p.getGradient().pow(2).sum().getFloat();\n", " }\n", - " double norm = Math.sqrt(results);\n", + " double norm = Math.sqrt(result);\n", " if (norm > theta) {\n", " for (NDArray param : net.params) {\n", " NDArray gradient = param.getGradient();\n", @@ -662,13 +614,14 @@ "metadata": {}, "outputs": [], "source": [ - "/**\n", - " * Train a model within one epoch.\n", - " */\n", - "public Pair trainEpochCh8(RNNModelScratch net, List trainIter, Loss loss,\n", - " voidFunction updater, \n", - " Device device, boolean useRandomIter) {\n", - " try {\n", + "/** Train a model within one epoch. */\n", + "public static Pair trainEpochCh8(\n", + " RNNModelScratch net,\n", + " List trainIter,\n", + " Loss loss,\n", + " voidFunction updater,\n", + " Device device,\n", + " boolean useRandomIter) {\n", " NDArray state = null;\n", " StopWatch watch = new StopWatch();\n", " watch.start();\n", @@ -688,43 +641,19 @@ " y = y.toDevice(device, false);\n", " try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n", " Pair pairResult = net.call(X, state);\n", - " NDArray yHat = pairResult.getValue();\n", + " NDArray yHat = pairResult.getKey();\n", " state = pairResult.getValue();\n", - " NDArray l = loss.evaluate(new NDList(yHat), new NDList(y)).mean();\n", + " NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();\n", " gc.backward(l);\n", - " gradClipping(net, 1);\n", - " updater.apply(1); // Since the `mean` function has been invoked\n", - " metric.add(new float[] {1 * y.size(), y.size()});\n", + " metric.add(new float[] {l.getFloat() * y.size(), y.size()});\n", " }\n", + " gradClipping(net, 1);\n", + " updater.apply(1); // Since the `mean` function has been invoked\n", " }\n", - " }catch (Exception e) {e.printStackTrace();}\n", " return new Pair(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());\n", "}" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "origin_pos": 42, - "tab": [ - "mxnet" - ] - }, - "outputs": [], - "source": [ - "def grad_clipping(net, theta): #@save\n", - " \"\"\"Clip the gradient.\"\"\"\n", - " if isinstance(net, gluon.Block):\n", - " params = [p.data() for p in net.collect_params().values()]\n", - " else:\n", - " params = net.params\n", - " norm = math.sqrt(sum((p.grad ** 2).sum() for p in params))\n", - " if norm > theta:\n", - " for param in params:\n", - " param.grad[:] *= theta / norm" - ] - }, { "cell_type": "markdown", "metadata": { @@ -743,35 +672,37 @@ "metadata": {}, "outputs": [], "source": [ - "/*\n", - " ** Train a model.\n", - " */\n", - "public void trainCh8(RNNModelScratch net, List trainIter, Vocab vocab,\n", - " int lr, int numEpochs, Device device, boolean useRandomIter) {\n", - " try {\n", + "/** Train a model. */\n", + "public static void trainCh8(\n", + " RNNModelScratch net,\n", + " List trainIter,\n", + " Vocab vocab,\n", + " int lr,\n", + " int numEpochs,\n", + " Device device,\n", + " boolean useRandomIter) {\n", " SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();\n", " Animator animator = new Animator();\n", " // Initialize\n", " voidFunction updater = (batchSize) -> Training.sgd(net.params, lr, batchSize);\n", - " Function predict = (prefix) -> \n", - " predictCh8(prefix, 50, net, vocab, device);\n", + " Function predict = (prefix) -> predictCh8(prefix, 50, net, vocab, device);\n", " // Train and predict\n", " double ppl = 0.0;\n", " double speed = 0.0;\n", " for (int epoch = 0; epoch < numEpochs; epoch++) {\n", - " Pair pair = \n", - " trainEpochCh8(net, trainIter, loss, updater, device, useRandomIter);\n", + " Pair pair =\n", + " trainEpochCh8(net, trainIter, loss, updater, device, useRandomIter);\n", " ppl = pair.getKey();\n", " speed = pair.getValue();\n", " if ((epoch + 1) % 10 == 0) {\n", - " animator.add(epoch + 1, (float) ppl, \"\");\n", + " animator.add(epoch + 1, (float) ppl, \"ppl\");\n", + " animator.show();\n", " }\n", " }\n", - " System.out.format(\"perplexity: %.1d, %.1d tokens/sec on %s%n\", \n", - " ppl, speed, device.toString());\n", + " System.out.format(\n", + " \"perplexity: %.1d, %.1d tokens/sec on %s%n\", ppl, speed, device.toString());\n", " System.out.println(predict.apply(\"time traveller\"));\n", " System.out.println(predict.apply(\"traveller\"));\n", - " }catch (Exception e) {e.printStackTrace();}\n", "}" ] }, @@ -788,7 +719,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ "int numEpochs = 500;\n", @@ -796,19 +729,6 @@ "trainCh8(net, trainIter, vocab, lr, numEpochs, Functions.tryGpu(0), false);" ] }, - { - "cell_type": "raw", - "metadata": { - "origin_pos": 54, - "tab": [ - "mxnet" - ] - }, - "source": [ - "num_epochs, lr = 500, 1\n", - "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())" - ] - }, { "cell_type": "markdown", "metadata": { @@ -864,36 +784,6 @@ "1. Change sequential partitioning so that it does not separate hidden states from the computational graph. Does the running time change? How about the perplexity?\n", "1. Replace the activation function used in this section with ReLU and repeat the experiments in this section. Do we still need gradient clipping? Why?\n" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "origin_pos": 54, - "tab": [ - "mxnet" - ] - }, - "outputs": [], - "source": [ - "num_epochs, lr = 500, 1\n", - "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "origin_pos": 57, - "tab": [ - "mxnet" - ] - }, - "outputs": [], - "source": [ - "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(),\n", - " use_random_iter=True)" - ] } ], "metadata": { From 3d02ce465895cca7429b2d57eb2d14f0bdf63548 Mon Sep 17 00:00:00 2001 From: markbookk Date: Mon, 1 Mar 2021 16:10:01 -0400 Subject: [PATCH 03/20] Fixing bug in TimeMachineUtils.java, was using instead of on seqDataIterRandom --- utils/TimeMachineUtils.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index cc6e5659..9677f3d0 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -217,12 +217,12 @@ public Pair, Vocab> loadDataTimeMachine(int batchSize, int num // subsequences List initialIndicesPerBatch = initialIndices.subList(i, i + batchSize); - NDArray xNDArray = manager.create(new Shape(initialIndices.size(), numSteps), DataType.INT32); - NDArray yNDArray = manager.create(new Shape(initialIndices.size(), numSteps), DataType.INT32); - for (int j = 0; j < initialIndices.size(); j++) { - ArrayList X = data(initialIndices.get(j), corpus, numSteps); + NDArray xNDArray = manager.create(new Shape(initialIndicesPerBatch.size(), numSteps), DataType.INT32); + NDArray yNDArray = manager.create(new Shape(initialIndicesPerBatch.size(), numSteps), DataType.INT32); + for (int j = 0; j < initialIndicesPerBatch.size(); j++) { + ArrayList X = data(initialIndicesPerBatch.get(j), corpus, numSteps); xNDArray.set(new NDIndex(j), manager.create(X.stream().mapToInt(Integer::intValue).toArray())); - ArrayList Y = data(initialIndices.get(j)+1, corpus, numSteps); + ArrayList Y = data(initialIndicesPerBatch.get(j)+1, corpus, numSteps); yNDArray.set(new NDIndex(j), manager.create(Y.stream().mapToInt(Integer::intValue).toArray())); } NDList pair = new NDList(); From e52268bc3186478cbe4027795dcd099be5f3538e Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 3 Mar 2021 10:29:13 -0400 Subject: [PATCH 04/20] Fixing memory problem by creating subNDManager and using it on all (or most) temp operations --- .../rnn-scratch.ipynb | 77 +++++++++++-------- utils/Functions.java | 25 ++++++ utils/Training.java | 12 +++ 3 files changed, 83 insertions(+), 31 deletions(-) diff --git a/chapter_recurrent-neural-networks/rnn-scratch.ipynb b/chapter_recurrent-neural-networks/rnn-scratch.ipynb index f521e14b..52737332 100644 --- a/chapter_recurrent-neural-networks/rnn-scratch.ipynb +++ b/chapter_recurrent-neural-networks/rnn-scratch.ipynb @@ -96,6 +96,11 @@ "@FunctionalInterface\n", "public interface voidFunction {\n", " public void apply(T t);\n", + "}\n", + "\n", + "@FunctionalInterface\n", + "public interface voidTwoFunction {\n", + " public void apply(T t, U u);\n", "}" ] }, @@ -148,11 +153,11 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": true + "scrolled": false }, "outputs": [], "source": [ - "manager.create(new int[] {0, 2}).oneHot(vocab.length())" + "manager.create(new int[] {0, 2}).oneHot(vocab.length()).toDevice(Functions.tryGpu(0), false)" ] }, { @@ -546,10 +551,12 @@ "outputs": [], "source": [ "/** Clip the gradient. */\n", - "public static void gradClipping(RNNModelScratch net, int theta) {\n", + "public static void gradClipping(RNNModelScratch net, int theta, NDManager manager) {\n", " double result = 0;\n", " for (NDArray p : net.params) {\n", - " result += p.getGradient().pow(2).sum().getFloat();\n", + " NDArray gradient = p.getGradient();\n", + " gradient.attach(manager);\n", + " result += gradient.pow(2).sum().getFloat();\n", " }\n", " double norm = Math.sqrt(result);\n", " if (norm > theta) {\n", @@ -619,36 +626,42 @@ " RNNModelScratch net,\n", " List trainIter,\n", " Loss loss,\n", - " voidFunction updater,\n", + " voidTwoFunction updater,\n", " Device device,\n", " boolean useRandomIter) {\n", - " NDArray state = null;\n", " StopWatch watch = new StopWatch();\n", " watch.start();\n", " Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens\n", - " for (NDList pair : trainIter) {\n", - " NDArray X = pair.get(0);\n", - " NDArray Y = pair.get(1);\n", - " if (state == null || useRandomIter) {\n", - " // Initialize `state` when either it is the first iteration or\n", - " // using random sampling\n", - " state = net.beginState((int) X.getShape().getShape()[0], device);\n", - " } else {\n", - " state.detach();\n", - " }\n", - " NDArray y = Y.transpose().reshape(new Shape(-1));\n", - " X = X.toDevice(device, false);\n", - " y = y.toDevice(device, false);\n", - " try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n", - " Pair pairResult = net.call(X, state);\n", - " NDArray yHat = pairResult.getKey();\n", - " state = pairResult.getValue();\n", - " NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();\n", - " gc.backward(l);\n", - " metric.add(new float[] {l.getFloat() * y.size(), y.size()});\n", + " try (NDManager childManager = manager.newSubManager()) {\n", + " NDArray state = null;\n", + " for (NDList pair : trainIter) {\n", + " NDArray X = pair.get(0).toDevice(Functions.tryGpu(0), true);\n", + " X.attach(childManager);\n", + " NDArray Y = pair.get(1).toDevice(Functions.tryGpu(0), true);\n", + " Y.attach(childManager);\n", + " if (state == null || useRandomIter) {\n", + " // Initialize `state` when either it is the first iteration or\n", + " // using random sampling\n", + " state = net.beginState((int) X.getShape().getShape()[0], device);\n", + " } else {\n", + " state = state.stopGradient();\n", + " }\n", + " state.attach(childManager);\n", + "\n", + " NDArray y = Y.transpose().reshape(new Shape(-1));\n", + " X = X.toDevice(device, false);\n", + " y = y.toDevice(device, false);\n", + " try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n", + " Pair pairResult = net.call(X, state);\n", + " NDArray yHat = pairResult.getKey();\n", + " state = pairResult.getValue();\n", + " NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();\n", + " gc.backward(l);\n", + " metric.add(new float[] {l.getFloat() * y.size(), y.size()});\n", + " }\n", + " gradClipping(net, 1, childManager);\n", + " updater.apply(1, childManager); // Since the `mean` function has been invoked\n", " }\n", - " gradClipping(net, 1);\n", - " updater.apply(1); // Since the `mean` function has been invoked\n", " }\n", " return new Pair(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());\n", "}" @@ -684,7 +697,8 @@ " SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();\n", " Animator animator = new Animator();\n", " // Initialize\n", - " voidFunction updater = (batchSize) -> Training.sgd(net.params, lr, batchSize);\n", + " voidTwoFunction updater =\n", + " (batchSize, subManager) -> Training.sgd(net.params, lr, batchSize, subManager);\n", " Function predict = (prefix) -> predictCh8(prefix, 50, net, vocab, device);\n", " // Train and predict\n", " double ppl = 0.0;\n", @@ -700,8 +714,9 @@ " }\n", " }\n", " System.out.format(\n", - " \"perplexity: %.1d, %.1d tokens/sec on %s%n\", ppl, speed, device.toString());\n", + " \"perplexity: %.1f, %.1f tokens/sec on %s%n\", ppl, speed, device.toString());\n", " System.out.println(predict.apply(\"time traveller\"));\n", + " \n", " System.out.println(predict.apply(\"traveller\"));\n", "}" ] @@ -724,7 +739,7 @@ }, "outputs": [], "source": [ - "int numEpochs = 500;\n", + "int numEpochs = 100;\n", "int lr = 1;\n", "trainCh8(net, trainIter, vocab, lr, numEpochs, Functions.tryGpu(0), false);" ] diff --git a/utils/Functions.java b/utils/Functions.java index 0b848e52..53133c1e 100644 --- a/utils/Functions.java +++ b/utils/Functions.java @@ -27,4 +27,29 @@ public static double[] floatToDoubleArray(float[] x) { public static Device tryGpu(int i) { return Device.getGpuCount() >= i + 1 ? Device.gpu(i) : Device.cpu(); } + + @FunctionalInterface + public interface TriFunction { + public W apply(T t, U u, V v); + } + + @FunctionalInterface + public interface QuadFunction { + public R apply(T t, U u, V v, W w); + } + + @FunctionalInterface + public interface SimpleFunction { + public T apply(); + } + + @FunctionalInterface + public interface voidFunction { + public void apply(T t); + } + + @FunctionalInterface + public interface voidTwoFunction { + public void apply(T t, U u); + } } \ No newline at end of file diff --git a/utils/Training.java b/utils/Training.java index 1cbae345..343552cf 100644 --- a/utils/Training.java +++ b/utils/Training.java @@ -33,6 +33,18 @@ public static void sgd(NDList params, float lr, int batchSize) { } } + /** Allow to do gradient calculations on subManager **/ + public static void sgd(NDList params, float lr, int batchSize, NDManager subManager) { + for (int i = 0; i < params.size(); i++) { + NDArray param = params.get(i); + // Update param in place. + // param = param - param.gradient * lr / batchSize + NDArray gradient = param.getGradient(); + gradient.attach(subManager); + param.subi(gradient.mul(lr).div(batchSize)); + } + } + public static float accuracy(NDArray yHat, NDArray y) { // Check size of 1st dimension greater than 1 // to see if we have multiple samples From 0e755372fe19aa65c990b1a051a30e81b7229c48 Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 3 Mar 2021 11:30:23 -0400 Subject: [PATCH 05/20] Fixing bug in gradClipping --- chapter_recurrent-neural-networks/rnn-scratch.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chapter_recurrent-neural-networks/rnn-scratch.ipynb b/chapter_recurrent-neural-networks/rnn-scratch.ipynb index 52737332..5d0a73a4 100644 --- a/chapter_recurrent-neural-networks/rnn-scratch.ipynb +++ b/chapter_recurrent-neural-networks/rnn-scratch.ipynb @@ -562,7 +562,7 @@ " if (norm > theta) {\n", " for (NDArray param : net.params) {\n", " NDArray gradient = param.getGradient();\n", - " gradient.set(new NDIndex(\":\"), theta / norm);\n", + " gradient.mul(theta / norm);\n", " }\n", " }\n", "}" From 489061b538b0256163033e85c6fdb021b896bd7f Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 3 Mar 2021 14:24:17 -0400 Subject: [PATCH 06/20] Fixing typo on previous commit --- chapter_recurrent-neural-networks/rnn-scratch.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chapter_recurrent-neural-networks/rnn-scratch.ipynb b/chapter_recurrent-neural-networks/rnn-scratch.ipynb index 5d0a73a4..4019113a 100644 --- a/chapter_recurrent-neural-networks/rnn-scratch.ipynb +++ b/chapter_recurrent-neural-networks/rnn-scratch.ipynb @@ -562,7 +562,7 @@ " if (norm > theta) {\n", " for (NDArray param : net.params) {\n", " NDArray gradient = param.getGradient();\n", - " gradient.mul(theta / norm);\n", + " gradient.muli(theta / norm);\n", " }\n", " }\n", "}" From 6605d4dddbf66869587f94ee578d10e3e2cba7f1 Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 3 Mar 2021 16:12:49 -0400 Subject: [PATCH 07/20] Changing epoch amount --- chapter_recurrent-neural-networks/rnn-scratch.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chapter_recurrent-neural-networks/rnn-scratch.ipynb b/chapter_recurrent-neural-networks/rnn-scratch.ipynb index 4019113a..cf417810 100644 --- a/chapter_recurrent-neural-networks/rnn-scratch.ipynb +++ b/chapter_recurrent-neural-networks/rnn-scratch.ipynb @@ -739,7 +739,7 @@ }, "outputs": [], "source": [ - "int numEpochs = 100;\n", + "int numEpochs = 500;\n", "int lr = 1;\n", "trainCh8(net, trainIter, vocab, lr, numEpochs, Functions.tryGpu(0), false);" ] From be03b1e13ae26a19866550166322d7cabc8db832 Mon Sep 17 00:00:00 2001 From: markbookk Date: Mon, 8 Mar 2021 17:07:13 -0400 Subject: [PATCH 08/20] Adding section 8.6: Consise Implemention of RNNs (still debugging result compared to d2l book) --- .../rnn-concise.ipynb | 384 ++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 chapter_recurrent-neural-networks/rnn-concise.ipynb diff --git a/chapter_recurrent-neural-networks/rnn-concise.ipynb b/chapter_recurrent-neural-networks/rnn-concise.ipynb new file mode 100644 index 00000000..ab3d9fd8 --- /dev/null +++ b/chapter_recurrent-neural-networks/rnn-concise.ipynb @@ -0,0 +1,384 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 0 + }, + "source": [ + "# Concise Implementation of Recurrent Neural Networks\n", + ":label:`sec_rnn-concise`\n", + "\n", + "While :numref:`sec_rnn_scratch` was instructive to see how RNNs are implemented,\n", + "this is not convenient or fast.\n", + "This section will show how to implement the same language model more efficiently\n", + "using functions provided by high-level APIs\n", + "of a deep learning framework.\n", + "We begin as before by reading the time machine dataset.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", + "%maven org.slf4j:slf4j-api:1.7.26\n", + "%maven org.slf4j:slf4j-simple:1.7.26\n", + " \n", + "%maven ai.djl:api:0.11.0-SNAPSHOT\n", + "%maven ai.djl:basicdataset:0.11.0-SNAPSHOT\n", + "\n", + "// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md\n", + "// MXNet \n", + "%maven ai.djl.mxnet:mxnet-engine:0.11.0-SNAPSHOT\n", + "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load ../utils/plot-utils\n", + "// %load ../utils/Functions.java\n", + "%load ../utils/PlotUtils.java\n", + "// %load ../utils/TimeMachineUtils.java\n", + "// %load ../utils/StopWatch.java\n", + "// %load ../utils/Accumulator.java\n", + "%load ../utils/Animator.java\n", + "// %load ../utils/Training.java\n", + "\n", + "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Functions.java\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/PlotUtils.java\n", + "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/TimeMachine.java\n", + "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/StopWatch.java\n", + "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Accumulator.java\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Animator.java\n", + "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Training.java\n", + "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/RNNModel.java\n", + "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/RNNModelScratch.java" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ai.djl.Device;\n", + "import ai.djl.ndarray.NDArray;\n", + "import ai.djl.ndarray.NDList;\n", + "import ai.djl.ndarray.NDManager;\n", + "import ai.djl.ndarray.types.DataType;\n", + "import ai.djl.ndarray.types.Shape;\n", + "import ai.djl.nn.recurrent.RNN;\n", + "import ai.djl.nn.AbstractBlock;\n", + "import ai.djl.nn.core.Linear;\n", + "import ai.djl.training.ParameterStore;\n", + "import ai.djl.util.Pair;\n", + "import ai.djl.util.PairList;\n", + "\n", + "import java.io.IOException;\n", + "import java.util.ArrayList;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "int batchSize = 32;\n", + "int numSteps = 35;\n", + "\n", + "Pair, Vocab> timeMachine =\n", + " loadDataTimeMachine(batchSize, numSteps, false, 10000);\n", + "ArrayList trainIter = timeMachine.getKey();\n", + "Vocab vocab = timeMachine.getValue();" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 3 + }, + "source": [ + "## Defining the Model\n", + "\n", + "High-level APIs provide implementations of recurrent neural networks.\n", + "We construct the recurrent neural network layer `rnn_layer` with a single hidden layer and 256 hidden units.\n", + "In fact, we have not even discussed yet what it means to have multiple layers---this will happen in :numref:`sec_deep_rnn`.\n", + "For now, suffice it to say that multiple layers simply amount to the output of one layer of RNN being used as the input for the next layer of RNN.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "int numHiddens = 256;\n", + "RNN rnnLayer = RNN.builder().setNumLayers(1)\n", + " .setStateSize(numHiddens).optReturnState(true).optBatchFirst(false).build();" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 6, + "tab": [ + "mxnet" + ] + }, + "source": [ + "Initializing the hidden state is straightforward.\n", + "We invoke the member function `beginState` _(In DJL we don't have to run `beginState` to later specify the resulting state the first time we run `forward`, as this logic is ran by DJL the first time we do `forward` but we will create it here for demonstration purposes)_.\n", + "This returns a list (`state`)\n", + "that contains\n", + "an initial hidden state\n", + "for each example in the minibatch,\n", + "whose shape is\n", + "(number of hidden layers, batch size, number of hidden units).\n", + "For some models \n", + "to be introduced later \n", + "(e.g., long short-term memory),\n", + "such a list also\n", + "contains other information." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "public static NDList beginState(int batchSize, int numLayers, int numHiddens) {\n", + " return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));\n", + "}\n", + "\n", + "NDList state = beginState(batchSize, 1, numHiddens);\n", + "System.out.println(state.size());\n", + "System.out.println(state.get(0).getShape());" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 10 + }, + "source": [ + "With a hidden state and an input,\n", + "we can compute the output with\n", + "the updated hidden state.\n", + "It should be emphasized that\n", + "the \"output\" (`Y`) of `rnnLayer`\n", + "does *not* involve computation of output layers:\n", + "it refers to \n", + "the hidden state at *each* time step,\n", + "and they can be used as the input\n", + "to the subsequent output layer." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 11, + "tab": [ + "mxnet" + ] + }, + "source": [ + "Besides,\n", + "the updated hidden state (`stateNew`) returned by `rnnLayer`\n", + "refers to the hidden state\n", + "at the *last* time step of the minibatch.\n", + "It can be used to initialize the \n", + "hidden state for the next minibatch within an epoch\n", + "in sequential partitioning.\n", + "For multiple hidden layers,\n", + "the hidden state of each layer will be stored\n", + "in this variable (`stateNew`).\n", + "For some models \n", + "to be introduced later \n", + "(e.g., long short-term memory),\n", + "this variable also\n", + "contains other information." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "NDArray X = manager.randomUniform (0, 1,new Shape(numSteps, batchSize, vocab.length()));\n", + "\n", + "NDList forwardOutput = rnnLayer.forward(new ParameterStore(manager, false), new NDList(X, state.get(0)), false);\n", + "NDArray Y = forwardOutput.get(0);\n", + "NDArray stateNew = forwardOutput.get(1);\n", + "\n", + "System.out.println(Y.getShape());\n", + "System.out.println(stateNew.getShape());" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 14 + }, + "source": [ + "Similar to :numref:`sec_rnn_scratch`,\n", + "we define an `RNNModel` class \n", + "for a complete RNN model.\n", + "Note that `rnnLayer` only contains the hidden recurrent layers, we need to create a separate output layer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "public class RNNModel extends AbstractBlock {\n", + "\n", + " private static final byte VERSION = 2;\n", + " private RNN rnnLayer;\n", + " private Linear dense;\n", + " private int vocabSize;\n", + "\n", + " public RNNModel(RNN rnnLayer, int vocabSize) {\n", + " super(VERSION);\n", + " this.rnnLayer = rnnLayer;\n", + " this.addChildBlock(\"rnn\", rnnLayer);\n", + " this.vocabSize = vocabSize;\n", + " this.dense = Linear.builder().setUnits(vocabSize).build();\n", + " }\n", + "\n", + " \n", + " @Override\n", + " protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList params) {\n", + " NDArray X = inputs.get(0).transpose().oneHot(this.vocabSize);\n", + " inputs.set(0, X);\n", + " NDList result = this.rnnLayer.forward(parameterStore, inputs, training);\n", + " NDArray Y = result.get(0);\n", + " NDArray state = result.get(1);\n", + "\n", + " int shapeLength = Y.getShape().getShape().length;\n", + " NDList output = this.dense.forward(parameterStore, new NDList(Y\n", + " .reshape(new Shape(-1, Y.getShape().get(shapeLength-1)))), training);\n", + " return new NDList(output.get(0), state);\n", + " }\n", + " \n", + " @Override\n", + " public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {\n", + " /* rnnLayer is already initialized so we don't have to do anything here, just override it.*/\n", + " return;\n", + " }\n", + "\n", + " /* We won't implement this since we won't be using it but it's required as part of an AbstractBlock */\n", + " @Override\n", + " public Shape[] getOutputShapes(Shape[] inputShapes) {\n", + " return new Shape[0];\n", + " }\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 17 + }, + "source": [ + "## Training and Predicting\n", + "\n", + "Before training the model, let us make a prediction with the a model that has random weights.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Device device = Functions.tryGpu(0);\n", + "RNNModel net = new RNNModel(rnnLayer, vocab.length());\n", + "net.initialize(manager, DataType.FLOAT32, X.getShape());\n", + "String prediction = TimeMachine.predictCh8(\"time traveller\", 10, net, vocab, device, manager);\n", + "System.out.println(prediction);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 20 + }, + "source": [ + "As is quite obvious, this model does not work at all. Next, we call `trainCh8` with the same hyperparameters defined in :numref:`sec_rnn_scratch` and train our model with high-level APIs.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "int numEpochs = 500;\n", + "int lr = 1;\n", + "TimeMachine.trainCh8((Object) net, trainIter, vocab, lr, numEpochs, device, false, manager);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "origin_pos": 22 + }, + "source": [ + "Compared with the last section, this model achieves comparable perplexity,\n", + "albeit within a shorter period of time, due to the code being more optimized by\n", + "high-level APIs of the deep learning framework.\n", + "\n", + "\n", + "## Summary\n", + "\n", + "* High-level APIs of the deep learning framework provides an implementation of the RNN layer.\n", + "* The RNN layer of high-level APIs returns an output and an updated hidden state, where the output does not involve output layer computation.\n", + "* Using high-level APIs leads to faster RNN training than using its implementation from scratch.\n", + "\n", + "## Exercises\n", + "\n", + "1. Can you make the RNN model overfit using the high-level APIs?\n", + "1. What happens if you increase the number of hidden layers in the RNN model? Can you make the model work?\n", + "1. Implement the autoregressive model of :numref:`sec_sequence` using an RNN.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Java", + "language": "java", + "name": "java" + }, + "language_info": { + "codemirror_mode": "java", + "file_extension": ".jshell", + "mimetype": "text/x-java-source", + "name": "Java", + "pygments_lexer": "java", + "version": "11.0.10+9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From f4bc81020f6a25db42ba67f479defb3ac13b4603 Mon Sep 17 00:00:00 2001 From: markbookk Date: Tue, 9 Mar 2021 10:45:07 -0400 Subject: [PATCH 09/20] Updating .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 9bea4330..7db97908 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ .DS_Store +testoutput/ +.ipynb_checkpoints/ \ No newline at end of file From d2f9bc90263d0b88dc58df0294dceb02be45cda9 Mon Sep 17 00:00:00 2001 From: markbookk Date: Tue, 9 Mar 2021 11:35:52 -0400 Subject: [PATCH 10/20] Adding finished notebook; adding util functions and chan ging TimeMachineUtils.java --- .../rnn-concise.ipynb | 162 ++++- utils/Functions.java | 28 +- utils/TimeMachineUtils.java | 667 +++++++++++++----- utils/Training.java | 13 +- 4 files changed, 670 insertions(+), 200 deletions(-) diff --git a/chapter_recurrent-neural-networks/rnn-concise.ipynb b/chapter_recurrent-neural-networks/rnn-concise.ipynb index ab3d9fd8..f6d3bc55 100644 --- a/chapter_recurrent-neural-networks/rnn-concise.ipynb +++ b/chapter_recurrent-neural-networks/rnn-concise.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -38,33 +38,46 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%load ../utils/plot-utils\n", - "// %load ../utils/Functions.java\n", + "%load ../utils/Functions.java\n", "%load ../utils/PlotUtils.java\n", - "// %load ../utils/TimeMachineUtils.java\n", - "// %load ../utils/StopWatch.java\n", - "// %load ../utils/Accumulator.java\n", + "\n", + "%load ../utils/StopWatch.java\n", + "%load ../utils/Accumulator.java\n", "%load ../utils/Animator.java\n", - "// %load ../utils/Training.java\n", + "%load ../utils/Training.java\n", + "\n", "\n", - "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Functions.java\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/PlotUtils.java\n", - "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/TimeMachine.java\n", - "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/StopWatch.java\n", - "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Accumulator.java\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Animator.java\n", - "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Training.java\n", - "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/RNNModel.java\n", - "%load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/RNNModelScratch.java" + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Functions.java\n", + "// // %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/PlotUtils.java\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/StopWatch.java\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Accumulator.java\n", + "// // %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Animator.java\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Training.java\n", + "\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/RNNModelScratch.java\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Vocab.java\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/SeqDataLoader.java\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/RNNModel.java\n", + "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/TimeMachine.java" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "%load ../utils/TimeMachineUtils.java" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -87,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -96,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -104,7 +117,7 @@ "int numSteps = 35;\n", "\n", "Pair, Vocab> timeMachine =\n", - " loadDataTimeMachine(batchSize, numSteps, false, 10000);\n", + " SeqDataLoader.loadDataTimeMachine(batchSize, numSteps, false, 10000, manager);\n", "ArrayList trainIter = timeMachine.getKey();\n", "Vocab vocab = timeMachine.getValue();" ] @@ -125,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -160,9 +173,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n", + "(1, 32, 256)\n" + ] + } + ], "source": [ "public static NDList beginState(int batchSize, int numLayers, int numHiddens) {\n", " return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));\n", @@ -219,9 +241,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(35, 32, 256)\n", + "(1, 32, 256)\n" + ] + } + ], "source": [ "NDArray X = manager.randomUniform (0, 1,new Shape(numSteps, batchSize, vocab.length()));\n", "\n", @@ -247,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -308,9 +339,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "time travellerddjjjjjdjj\n" + ] + } + ], "source": [ "Device device = Functions.tryGpu(0);\n", "RNNModel net = new RNNModel(rnnLayer, vocab.length());\n", @@ -330,9 +369,72 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: cpu().\n", + "[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.7.0 in 0.069 ms.\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + "\n" + ], + "text/plain": [ + "tech.tablesaw.plotly.components.Figure@4ed507e1" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "perplexity: 5.6, 69994.2 tokens/sec on cpu()\n", + "time traveller ou there thoe int and thene ta the this thaw ol \n", + "travellertand mh the thme travel er oae sooeee te tou than \n" + ] + } + ], "source": [ "int numEpochs = 500;\n", "int lr = 1;\n", diff --git a/utils/Functions.java b/utils/Functions.java index 53133c1e..a3bf7d16 100644 --- a/utils/Functions.java +++ b/utils/Functions.java @@ -1,3 +1,5 @@ +import ai.djl.Device; + import java.util.function.Function; public class Functions { @@ -10,7 +12,7 @@ public static float[] callFunc(float[] x, Function func) { } return y; } - + // ScatterTrace.builder() does not support float[], // so we must convert to a double array first public static double[] floatToDoubleArray(float[] x) { @@ -21,35 +23,51 @@ public static double[] floatToDoubleArray(float[] x) { return ret; } - /** - * Return the i'th GPU if it exists, otherwise return the CPU - */ + /** Return the i'th GPU if it exists, otherwise return the CPU */ public static Device tryGpu(int i) { return Device.getGpuCount() >= i + 1 ? Device.gpu(i) : Device.cpu(); } + /** + * Helper function to later be able to use lambda. Accepts three types for parameters and one + * for output. + */ @FunctionalInterface public interface TriFunction { public W apply(T t, U u, V v); } + /** + * Helper function to later be able to use lambda. Accepts 4 types for parameters and one + * for output. + */ @FunctionalInterface public interface QuadFunction { public R apply(T t, U u, V v, W w); } + /** + * Helper function to later be able to use lambda. Doesn't have any type for parameters and has one type + * for output. + */ @FunctionalInterface public interface SimpleFunction { public T apply(); } + /** + * Helper function to later be able to use lambda. Accepts one types for parameters and uses void for return. + */ @FunctionalInterface public interface voidFunction { public void apply(T t); } + /** + * Helper function to later be able to use lambda. Accepts two types for parameters and uses void for return. + */ @FunctionalInterface public interface voidTwoFunction { public void apply(T t, U u); } -} \ No newline at end of file +} diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index 9677f3d0..21024404 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -1,11 +1,24 @@ import ai.djl.ndarray.*; import ai.djl.ndarray.types.*; import ai.djl.ndarray.index.*; -import ai.djl.util.Pair; +import ai.djl.util.*; +import ai.djl.*; +import ai.djl.engine.*; +import ai.djl.nn.*; +import ai.djl.nn.core.*; +import ai.djl.nn.recurrent.*; +import ai.djl.training.*; +import ai.djl.training.loss.*; +import ai.djl.training.tracker.*; +import ai.djl.training.initializer.*; +import ai.djl.training.evaluator.*; +import ai.djl.training.optimizer.*; +import ai.djl.training.listener.*; import java.io.*; import java.net.URL; import java.util.*; +import java.util.function.*; public class Vocab { public int unk; @@ -16,14 +29,16 @@ public class Vocab { public Vocab(String[][] tokens, int minFreq, String[] reservedTokens) { // Sort according to frequencies HashMap counter = countCorpus2D(tokens); - this.tokenFreqs = new ArrayList>(counter.entrySet()); - Collections.sort(tokenFreqs, - new Comparator>() { - public int compare(Map.Entry o1, Map.Entry o2) { - return (o2.getValue()).compareTo(o1.getValue()); - } - }); - + this.tokenFreqs = new ArrayList>(counter.entrySet()); + Collections.sort( + tokenFreqs, + new Comparator>() { + public int compare( + Map.Entry o1, Map.Entry o2) { + return (o2.getValue()).compareTo(o1.getValue()); + } + }); + // The index for the unknown token is 0 this.unk = 0; List uniqTokens = new ArrayList<>(); @@ -34,87 +49,281 @@ public int compare(Map.Entry o1, Map.Entry o2) uniqTokens.add(entry.getKey()); } } - + this.idxToToken = new ArrayList<>(); this.tokenToIdx = new HashMap<>(); for (String token : uniqTokens) { this.idxToToken.add(token); - this.tokenToIdx.put(token, this.idxToToken.size()-1); + this.tokenToIdx.put(token, this.idxToToken.size() - 1); } } - + public int length() { return this.idxToToken.size(); } - + public Integer[] getIdxs(String[] tokens) { List idxs = new ArrayList<>(); for (String token : tokens) { idxs.add(getIdx(token)); } return idxs.toArray(new Integer[0]); - } - + public Integer getIdx(String token) { return this.tokenToIdx.getOrDefault(token, this.unk); } - - -} -/** - * Count token frequencies. - */ -public HashMap countCorpus(String[] tokens) { - - HashMap counter = new HashMap<>(); - if (tokens.length != 0) { - for (String token : tokens) { - counter.put(token, counter.getOrDefault(token, 0)+1); + /** Count token frequencies. */ + public HashMap countCorpus(String[] tokens) { + + HashMap counter = new HashMap<>(); + if (tokens.length != 0) { + for (String token : tokens) { + counter.put(token, counter.getOrDefault(token, 0) + 1); + } } + return counter; + } + + /** Flatten a list of token lists into a list of tokens */ + public HashMap countCorpus2D(String[][] tokens) { + List allTokens = new ArrayList(); + for (int i = 0; i < tokens.length; i++) { + for (int j = 0; j < tokens[i].length; j++) { + if (tokens[i][j] != "") { + allTokens.add(tokens[i][j]); + } + } + } + return countCorpus(allTokens.toArray(new String[0])); } - return counter; } -/** - * Flatten a list of token lists into a list of tokens - */ -public HashMap countCorpus2D(String[][] tokens) { - List allTokens = new ArrayList(); - for (int i = 0; i < tokens.length; i++) { - for (int j = 0; j < tokens[i].length; j++) { - if (tokens[i][j] != "") { - allTokens.add(tokens[i][j]); - } +public class SeqDataLoader implements Iterable { + public ArrayList dataIter; + public List corpus; + public Vocab vocab; + public int batchSize; + public int numSteps; + + /* An iterator to load sequence data. */ + @SuppressWarnings("unchecked") + public SeqDataLoader( + int batchSize, int numSteps, boolean useRandomIter, int maxTokens, NDManager manager) + throws IOException, Exception { + Pair, Vocab> corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens); + this.corpus = corpusVocabPair.getKey(); + this.vocab = corpusVocabPair.getValue(); + + this.batchSize = batchSize; + this.numSteps = numSteps; + if (useRandomIter) { + dataIter = seqDataIterRandom(corpus, batchSize, numSteps, manager); + } else { + dataIter = seqDataIterSequential(corpus, batchSize, numSteps, manager); } } - return countCorpus(allTokens.toArray(new String[0])); + + @Override + public Iterator iterator() { + return dataIter.iterator(); + } + + /** Return the iterator and the vocabulary of the time machine dataset. */ + public static Pair, Vocab> loadDataTimeMachine( + int batchSize, int numSteps, boolean useRandomIter, int maxTokens, NDManager manager) + throws IOException, Exception { + + SeqDataLoader seqData = + new SeqDataLoader(batchSize, numSteps, useRandomIter, maxTokens, manager); + return new Pair(seqData.dataIter, seqData.vocab); // ArrayList, Vocab + } + + /** Generate a minibatch of subsequences using random sampling. */ + public ArrayList seqDataIterRandom( + List corpus, int batchSize, int numSteps, NDManager manager) { + // Start with a random offset (inclusive of `numSteps - 1`) to partition a + // sequence + corpus = corpus.subList(new Random().nextInt(numSteps - 1), corpus.size()); + // Subtract 1 since we need to account for labels + int numSubseqs = (corpus.size() - 1) / numSteps; + // The starting indices for subsequences of length `numSteps` + List initialIndices = new ArrayList<>(); + for (int i = 0; i < numSubseqs * numSteps; i += numSteps) { + initialIndices.add(i); + } + // In random sampling, the subsequences from two adjacent random + // minibatches during iteration are not necessarily adjacent on the + // original sequence + Collections.shuffle(initialIndices); + + int numBatches = numSubseqs / batchSize; + + ArrayList pairs = new ArrayList(); + for (int i = 0; i < batchSize * numBatches; i += batchSize) { + // Here, `initialIndices` contains randomized starting indices for + // subsequences + List initialIndicesPerBatch = initialIndices.subList(i, i + batchSize); + + NDArray xNDArray = + manager.create( + new Shape(initialIndicesPerBatch.size(), numSteps), DataType.FLOAT32); + NDArray yNDArray = + manager.create( + new Shape(initialIndicesPerBatch.size(), numSteps), DataType.FLOAT32); + for (int j = 0; j < initialIndicesPerBatch.size(); j++) { + ArrayList X = data(initialIndicesPerBatch.get(j), corpus, numSteps); + xNDArray.set( + new NDIndex(j), + manager.create(X.stream().mapToInt(Integer::intValue).toArray())); + ArrayList Y = data(initialIndicesPerBatch.get(j) + 1, corpus, numSteps); + yNDArray.set( + new NDIndex(j), + manager.create(Y.stream().mapToInt(Integer::intValue).toArray())); + } + NDList pair = new NDList(); + pair.add(xNDArray); + pair.add(yNDArray); + pairs.add(pair); + } + return pairs; + } + + ArrayList data(int pos, List corpus, int numSteps) { + // Return a sequence of length `numSteps` starting from `pos` + return new ArrayList(corpus.subList(pos, pos + numSteps)); + } + + /** Generate a minibatch of subsequences using sequential partitioning. */ + public ArrayList seqDataIterSequential( + List corpus, int batchSize, int numSteps, NDManager manager) { + // Start with a random offset to partition a sequence + int offset = new Random().nextInt(numSteps); + int numTokens = ((corpus.size() - offset - 1) / batchSize) * batchSize; + + NDArray Xs = + manager.create( + corpus.subList(offset, offset + numTokens).stream() + .mapToInt(Integer::intValue) + .toArray()); + NDArray Ys = + manager.create( + corpus.subList(offset + 1, offset + 1 + numTokens).stream() + .mapToInt(Integer::intValue) + .toArray()); + Xs = Xs.reshape(new Shape(batchSize, -1)); + Ys = Ys.reshape(new Shape(batchSize, -1)); + int numBatches = (int) Xs.getShape().get(1) / numSteps; + + ArrayList pairs = new ArrayList(); + for (int i = 0; i < numSteps * numBatches; i += numSteps) { + NDArray X = Xs.get(new NDIndex(":, {}:{}", i, i + numSteps)); + NDArray Y = Ys.get(new NDIndex(":, {}:{}", i, i + numSteps)); + NDList pair = new NDList(); + pair.add(X); + pair.add(Y); + pairs.add(pair); + } + return pairs; + } +} + +/** An RNN Model implemented from scratch. */ +public class RNNModelScratch { + public int vocabSize; + public int numHiddens; + public NDList params; + public Functions.TriFunction initState; + public Functions.TriFunction forwardFn; + + public RNNModelScratch( + int vocabSize, + int numHiddens, + Device device, + Functions.TriFunction getParams, + Functions.TriFunction initRNNState, + Functions.TriFunction forwardFn) { + this.vocabSize = vocabSize; + this.numHiddens = numHiddens; + this.params = getParams.apply(vocabSize, numHiddens, device); + this.initState = initRNNState; + this.forwardFn = forwardFn; + } + + public Pair call(NDArray X, NDArray state) { + X = X.transpose().oneHot(this.vocabSize); + return this.forwardFn.apply(X, state, this.params); + } + + public NDArray beginState(int batchSize, Device device) { + return this.initState.apply(batchSize, this.numHiddens, device); + } +} + +public class RNNModel extends AbstractBlock { + + private static final byte VERSION = 2; + private RNN rnnLayer; + private Linear dense; + private int vocabSize; + + public RNNModel(RNN rnnLayer, int vocabSize) { + super(VERSION); + this.rnnLayer = rnnLayer; + this.addChildBlock("rnn", rnnLayer); + this.vocabSize = vocabSize; + this.dense = Linear.builder().setUnits(vocabSize).build(); + } + + /** {@inheritDoc} */ + @Override + public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { + /* rnnLayer is already initialized so we don't have to do anything here, just override it.*/ + return; + } + + @Override + protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList params) { + NDArray X = inputs.get(0).transpose().oneHot(this.vocabSize); + inputs.set(0, X); + NDList result = this.rnnLayer.forward(parameterStore, inputs, training); + NDArray Y = result.get(0); + NDArray state = result.get(1); + + int shapeLength = Y.getShape().getShape().length; + NDList output = this.dense.forward(parameterStore, new NDList(Y + .reshape(new Shape(-1, Y.getShape().get(shapeLength-1)))), training); + return new NDList(output.get(0), state); + } + + + /* We won't implement this since we won't be using it but it's required as part of an AbstractBlock */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + return new Shape[0]; + } } public class TimeMachine { - /** - * Split text lines into word or character tokens. - */ + /** Split text lines into word or character tokens. */ public static String[][] tokenize(String[] lines, String token) throws Exception { String[][] output = new String[lines.length][]; if (token == "word") { for (int i = 0; i < output.length; i++) { output[i] = lines[i].split(" "); } - }else if (token == "char") { + } else if (token == "char") { for (int i = 0; i < output.length; i++) { output[i] = lines[i].split(""); } - }else { + } else { throw new Exception("ERROR: unknown token type: " + token); } - return output; + return output; } - /** - * Read `The Time Machine` dataset and return an array of the lines - */ + /** Read `The Time Machine` dataset and return an array of the lines */ public static String[] readTimeMachine() throws IOException { URL url = new URL("http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt"); String[] lines; @@ -128,10 +337,9 @@ public static String[] readTimeMachine() throws IOException { return lines; } - /** - * Return token indices and the vocabulary of the time machine dataset. - */ - public static Pair, Vocab> loadCorpusTimeMachine(int maxTokens) throws IOException, Exception { + /** Return token indices and the vocabulary of the time machine dataset. */ + public static Pair, Vocab> loadCorpusTimeMachine(int maxTokens) + throws IOException, Exception { String[] lines = readTimeMachine(); String[][] tokens = tokenize(lines, "char"); Vocab vocab = new Vocab(tokens, 0, new String[0]); @@ -150,119 +358,260 @@ public static Pair, Vocab> loadCorpusTimeMachine(int maxTokens) th } return new Pair(corpus, vocab); } -} -public class SeqDataLoader implements Iterable { - public ArrayList dataIter; - public List corpus; - public Vocab vocab; - public int batchSize; - public int numSteps; - - /* An iterator to load sequence data. */ - @SuppressWarnings("unchecked") - public SeqDataLoader(int batchSize, int numSteps, boolean useRandomIter, int maxTokens) throws IOException, Exception { - Pair, Vocab> corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens); - this.corpus = corpusVocabPair.getKey(); - this.vocab = corpusVocabPair.getValue(); - - this.batchSize = batchSize; - this.numSteps = numSteps; - if (useRandomIter) { - dataIter = seqDataIterRandom(corpus, batchSize, numSteps, manager); - }else { - dataIter = seqDataIterSequential(corpus, batchSize, numSteps, manager); + /** Generate new characters following the `prefix`. */ + public static String predictCh8( + String prefix, + int numPreds, + Object net, + Vocab vocab, + Device device, + NDManager manager) { + + List outputs = new ArrayList<>(); + outputs.add(vocab.getIdx("" + prefix.charAt(0))); + Functions.SimpleFunction getInput = + () -> + manager.create(outputs.get(outputs.size() - 1)) + .toDevice(device, false) + .reshape(new Shape(1, 1)); + + if (net instanceof RNNModelScratch) { + RNNModelScratch castedNet = (RNNModelScratch) net; + NDArray state = castedNet.beginState(1, device); + + for (char c : prefix.substring(1).toCharArray()) { // Warm-up period + state = (NDArray) castedNet.call(getInput.apply(), state).getValue(); + outputs.add(vocab.getIdx("" + c)); + } + + NDArray y; + for (int i = 0; i < numPreds; i++) { + Pair pair = castedNet.call(getInput.apply(), state); + y = pair.getKey(); + state = pair.getValue(); + + outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L)); + } + } else { + RNNModel castedNet = (RNNModel) net; + NDArray state = null; + for (char c : prefix.substring(1).toCharArray()) { // Warm-up period + if (state == null) { + // Begin state + state = + castedNet + .forwardInternal( + new ParameterStore(manager, false), + new NDList(getInput.apply()), + false, + null) + .get(1); + } else { + state = + castedNet + .forwardInternal( + new ParameterStore(manager, false), + new NDList(getInput.apply(), state), + false, + null) + .get(1); + } + outputs.add(vocab.getIdx("" + c)); + } + + NDArray y; + for (int i = 0; i < numPreds; i++) { + NDList pair = + castedNet.forwardInternal( + new ParameterStore(manager, false), + new NDList(getInput.apply(), state), + false, + null); + y = pair.get(0); + state = pair.get(1); + + outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L)); + } } - } - - @Override - public Iterator iterator() { - return dataIter.iterator(); - } -} -/** - * Return the iterator and the vocabulary of the time machine dataset. - */ -public Pair, Vocab> loadDataTimeMachine(int batchSize, int numSteps, boolean useRandomIter, int maxTokens) throws IOException, Exception { - - SeqDataLoader seqData = new SeqDataLoader(batchSize, numSteps, useRandomIter, maxTokens); - return new Pair(seqData.dataIter, seqData.vocab); // ArrayList, Vocab -} -/** - * Generate a minibatch of subsequences using random sampling. - */ -public ArrayList - seqDataIterRandom(List corpus, int batchSize, int numSteps, NDManager manager) { - // Start with a random offset (inclusive of `numSteps - 1`) to partition a - // sequence - corpus = corpus.subList(new Random().nextInt(numSteps - 1), corpus.size()); - // Subtract 1 since we need to account for labels - int numSubseqs = (corpus.size() - 1) / numSteps; - // The starting indices for subsequences of length `numSteps` - List initialIndices = new ArrayList<>(); - for (int i = 0; i < numSubseqs * numSteps; i += numSteps) { - initialIndices.add(i); + String outputString = ""; + for (int i : outputs) { + outputString += vocab.idxToToken.get(i); + } + return outputString; } - // In random sampling, the subsequences from two adjacent random - // minibatches during iteration are not necessarily adjacent on the - // original sequence - Collections.shuffle(initialIndices); - - int numBatches = numSubseqs / batchSize; - - ArrayList pairs = new ArrayList(); - for (int i = 0; i < batchSize * numBatches; i += batchSize) { - // Here, `initialIndices` contains randomized starting indices for - // subsequences - List initialIndicesPerBatch = initialIndices.subList(i, i + batchSize); - - NDArray xNDArray = manager.create(new Shape(initialIndicesPerBatch.size(), numSteps), DataType.INT32); - NDArray yNDArray = manager.create(new Shape(initialIndicesPerBatch.size(), numSteps), DataType.INT32); - for (int j = 0; j < initialIndicesPerBatch.size(); j++) { - ArrayList X = data(initialIndicesPerBatch.get(j), corpus, numSteps); - xNDArray.set(new NDIndex(j), manager.create(X.stream().mapToInt(Integer::intValue).toArray())); - ArrayList Y = data(initialIndicesPerBatch.get(j)+1, corpus, numSteps); - yNDArray.set(new NDIndex(j), manager.create(Y.stream().mapToInt(Integer::intValue).toArray())); + + /** Train a model. */ + public static void trainCh8( + Object net, + List trainIter, + Vocab vocab, + int lr, + int numEpochs, + Device device, + boolean useRandomIter, + NDManager manager) { + SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss(); + Animator animator = new Animator(); + + Functions.voidTwoFunction updater; + if (net instanceof RNNModelScratch) { + RNNModelScratch castedNet = (RNNModelScratch) net; + updater = + (batchSize, subManager) -> + Training.sgd(castedNet.params, lr, batchSize, subManager); + } else { + // Already initialized net + RNNModel castedNet = (RNNModel) net; + Model model = Model.newInstance("model"); + model.setBlock(castedNet); + + Tracker lrt = Tracker.fixed(0.1f); + Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build(); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(loss) + .optOptimizer(sgd) // Optimizer (loss function) + .optInitializer( + new NormalInitializer(0.01f), + Parameter.Type.WEIGHT) // setting the initializer + .optDevices(Device.getDevices(1)) // setting the number of GPUs needed + .addEvaluator(new Accuracy()) // Model Accuracy + .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging + + Trainer trainer = model.newTrainer(config); + updater = (batchSize, subManager) -> trainer.step(); + } + + Function predict = + (prefix) -> predictCh8(prefix, 50, net, vocab, device, manager); + // Train and predict + double ppl = 0.0; + double speed = 0.0; + for (int epoch = 0; epoch < numEpochs; epoch++) { + // System.out.println("Epoch: " + epoch); + Pair pair = + trainEpochCh8(net, trainIter, loss, updater, device, useRandomIter, manager); + ppl = pair.getKey(); + speed = pair.getValue(); + if ((epoch + 1) % 10 == 0) { + animator.add(epoch + 1, (float) ppl, "ppl"); + animator.show(); + } + // System.out.format( + // "perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, + // device.toString()); } - NDList pair = new NDList(); - pair.add(xNDArray); - pair.add(yNDArray); - pairs.add(pair); + System.out.format( + "perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, device.toString()); + System.out.println(predict.apply("time traveller")); + System.out.println(predict.apply("traveller")); } - return pairs; -} -ArrayList data(int pos, List corpus, int numSteps) { - // Return a sequence of length `numSteps` starting from `pos` - return new ArrayList(corpus.subList(pos, pos + numSteps)); -} + /** Train a model within one epoch. */ + public static Pair trainEpochCh8( + Object net, + List trainIter, + Loss loss, + Functions.voidTwoFunction updater, + Device device, + boolean useRandomIter, + NDManager manager) { + StopWatch watch = new StopWatch(); + watch.start(); + Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens + + try (NDManager childManager = manager.newSubManager()) { + NDArray state = null; + for (NDList pair : trainIter) { + NDArray X = pair.get(0).toDevice(Functions.tryGpu(0), true); + X.attach(childManager); + NDArray Y = pair.get(1).toDevice(Functions.tryGpu(0), true); + Y.attach(childManager); + if (state == null || useRandomIter) { + // Initialize `state` when either it is the first iteration or + // using random sampling + if (net instanceof RNNModelScratch) { + state = + ((RNNModelScratch) net) + .beginState((int) X.getShape().getShape()[0], device); + } + } else { + state.stopGradient(); + } + if (state != null) { + state.attach(childManager); + } + + NDArray y = Y.transpose().reshape(new Shape(-1)); + X = X.toDevice(device, false); + y = y.toDevice(device, false); + try (GradientCollector gc = Engine.getInstance().newGradientCollector()) { + NDArray yHat; + if (net instanceof RNNModelScratch) { + Pair pairResult = ((RNNModelScratch) net).call(X, state); + yHat = pairResult.getKey(); + state = pairResult.getValue(); + } else { + NDList pairResult; + if (state == null) { + // Begin state + pairResult = + ((RNNModel) net) + .forwardInternal( + new ParameterStore(manager, false), + new NDList(X), + true, + null); + } else { + pairResult = + ((RNNModel) net) + .forwardInternal( + new ParameterStore(manager, false), + new NDList(X, state), + true, + null); + } + yHat = pairResult.get(0); + state = pairResult.get(1); + } + + NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean(); + // System.out.println("Loss: " + l.getFloat()); + gc.backward(l); + metric.add(new float[] {l.getFloat() * y.size(), y.size()}); + } + gradClipping(net, 1, childManager); + updater.apply(1, childManager); // Since the `mean` function has been invoked + } + } + return new Pair(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop()); + } -/** - * Generate a minibatch of subsequences using sequential partitioning. - */ -public ArrayList seqDataIterSequential(List corpus, int batchSize, int numSteps, NDManager manager) { - // Start with a random offset to partition a sequence - int offset = new Random().nextInt(numSteps); - int numTokens = ((corpus.size() - offset - 1) / batchSize) * batchSize; - - NDArray Xs = manager.create( - corpus.subList(offset, offset + numTokens).stream().mapToInt(Integer::intValue).toArray()); - NDArray Ys = manager.create( - corpus.subList(offset + 1, offset + 1 + numTokens).stream().mapToInt(Integer::intValue).toArray()); - Xs = Xs.reshape(new Shape(batchSize, -1)); - Ys = Ys.reshape(new Shape(batchSize, -1)); - int numBatches = (int) Xs.getShape().get(1) / numSteps; - - - ArrayList pairs = new ArrayList(); - for (int i = 0; i < numSteps * numBatches; i += numSteps) { - NDArray X = Xs.get(new NDIndex(":, {}:{}", i, i + numSteps)); - NDArray Y = Ys.get(new NDIndex(":, {}:{}", i, i + numSteps)); - NDList pair = new NDList(); - pair.add(X); - pair.add(Y); - pairs.add(pair); + /** Clip the gradient. */ + public static void gradClipping(Object net, int theta, NDManager manager) { + double result = 0; + NDList params; + if (net instanceof RNNModelScratch) { + params = ((RNNModelScratch) net).params; + } else { + params = new NDList(); + for (Pair pair : ((RNNModel) net).getParameters()) { + params.add(pair.getValue().getArray()); + } + } + for (NDArray p : params) { + NDArray gradient = p.getGradient().stopGradient(); + gradient.attach(manager); + result += gradient.pow(2).sum().getFloat(); + } + double norm = Math.sqrt(result); + if (norm > theta) { + for (NDArray param : params) { + NDArray gradient = param.getGradient().stopGradient(); + param.getGradient().set(new NDIndex(":"), gradient.mul(theta / norm)); + } + } } - return pairs; -} \ No newline at end of file +} diff --git a/utils/Training.java b/utils/Training.java index 343552cf..44b19b1c 100644 --- a/utils/Training.java +++ b/utils/Training.java @@ -1,17 +1,18 @@ -import ai.djl.ndarray.*; import ai.djl.metric.Metrics; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; import ai.djl.training.EasyTrain; import ai.djl.training.Trainer; import ai.djl.training.dataset.ArrayDataset; - -import java.io.IOException; -import java.util.Map; -import ai.djl.ndarray.types.DataType; import ai.djl.training.dataset.Batch; import ai.djl.translate.TranslateException; -import java.util.function.UnaryOperator; +import java.io.IOException; +import java.util.Map; import java.util.function.BinaryOperator; +import java.util.function.UnaryOperator; class Training { From 122e8bffd0735bce302b466999073f835ebdd8a5 Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 10 Mar 2021 11:20:52 -0400 Subject: [PATCH 11/20] Changing RNNModels to AbstractBlocks --- utils/TimeMachineUtils.java | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index 21024404..ee9c40b9 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -394,27 +394,25 @@ public static String predictCh8( outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L)); } } else { - RNNModel castedNet = (RNNModel) net; + AbstractBlock castedNet = (AbstractBlock) net; NDArray state = null; for (char c : prefix.substring(1).toCharArray()) { // Warm-up period if (state == null) { // Begin state state = castedNet - .forwardInternal( + .forward( new ParameterStore(manager, false), new NDList(getInput.apply()), - false, - null) + false) .get(1); } else { state = castedNet - .forwardInternal( + .forward( new ParameterStore(manager, false), new NDList(getInput.apply(), state), - false, - null) + false) .get(1); } outputs.add(vocab.getIdx("" + c)); @@ -423,11 +421,10 @@ public static String predictCh8( NDArray y; for (int i = 0; i < numPreds; i++) { NDList pair = - castedNet.forwardInternal( + castedNet.forward( new ParameterStore(manager, false), new NDList(getInput.apply(), state), - false, - null); + false); y = pair.get(0); state = pair.get(1); @@ -463,7 +460,7 @@ public static void trainCh8( Training.sgd(castedNet.params, lr, batchSize, subManager); } else { // Already initialized net - RNNModel castedNet = (RNNModel) net; + AbstractBlock castedNet = (AbstractBlock) net; Model model = Model.newInstance("model"); model.setBlock(castedNet); @@ -558,20 +555,18 @@ public static Pair trainEpochCh8( if (state == null) { // Begin state pairResult = - ((RNNModel) net) - .forwardInternal( + ((AbstractBlock) net) + .forward( new ParameterStore(manager, false), new NDList(X), - true, - null); + true); } else { pairResult = - ((RNNModel) net) - .forwardInternal( + ((AbstractBlock) net) + .forward( new ParameterStore(manager, false), new NDList(X, state), - true, - null); + true); } yHat = pairResult.get(0); state = pairResult.get(1); @@ -597,7 +592,7 @@ public static void gradClipping(Object net, int theta, NDManager manager) { params = ((RNNModelScratch) net).params; } else { params = new NDList(); - for (Pair pair : ((RNNModel) net).getParameters()) { + for (Pair pair : ((AbstractBlock) net).getParameters()) { params.add(pair.getValue().getArray()); } } From 7591ef496e19bc73b55421b8d562e43c1c1233c1 Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 10 Mar 2021 11:27:06 -0400 Subject: [PATCH 12/20] Adding dense layer as a child block --- .../rnn-concise.ipynb | 23 +------------------ 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/chapter_recurrent-neural-networks/rnn-concise.ipynb b/chapter_recurrent-neural-networks/rnn-concise.ipynb index f6d3bc55..4359355f 100644 --- a/chapter_recurrent-neural-networks/rnn-concise.ipynb +++ b/chapter_recurrent-neural-networks/rnn-concise.ipynb @@ -50,28 +50,6 @@ "%load ../utils/Accumulator.java\n", "%load ../utils/Animator.java\n", "%load ../utils/Training.java\n", - "\n", - "\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Functions.java\n", - "// // %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/PlotUtils.java\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/StopWatch.java\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Accumulator.java\n", - "// // %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Animator.java\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Training.java\n", - "\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/RNNModelScratch.java\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/Vocab.java\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/SeqDataLoader.java\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/RNNModel.java\n", - "// %load /Users/nivesmn/Documents/projects/intellij_notebooks/section8_6/src/main/java/TimeMachine.java" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ "%load ../utils/TimeMachineUtils.java" ] }, @@ -295,6 +273,7 @@ " this.addChildBlock(\"rnn\", rnnLayer);\n", " this.vocabSize = vocabSize;\n", " this.dense = Linear.builder().setUnits(vocabSize).build();\n", + " this.addChildBlock(\"linear\", dense);\n", " }\n", "\n", " \n", From 616f5b1a205c5f26065c3cb0fe5556d0a4b55ec6 Mon Sep 17 00:00:00 2001 From: markbookk Date: Mon, 15 Mar 2021 10:07:18 -0400 Subject: [PATCH 13/20] Fixing bug (learning rate was not set appropiately), adding TimeMachineDataset, cleaning code, and adding documentation --- .../rnn-concise.ipynb | 278 +++++++++++------- utils/TimeMachineUtils.java | 144 ++++++++- 2 files changed, 298 insertions(+), 124 deletions(-) diff --git a/chapter_recurrent-neural-networks/rnn-concise.ipynb b/chapter_recurrent-neural-networks/rnn-concise.ipynb index 4359355f..1fbae4f3 100644 --- a/chapter_recurrent-neural-networks/rnn-concise.ipynb +++ b/chapter_recurrent-neural-networks/rnn-concise.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -55,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -78,26 +78,174 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a Dataset in DJL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In DJL, the ideal and concise way of dealing with datasets, is to use the built-in datasets that can easily wrap around existing NDArrays or to create your own dataset that extends from the `RandomAccessDataset` class. For this section, we will be implementing our own. For more information on creating your own dataset in DJL, you can refer to: https://djl.ai/docs/development/how_to_use_dataset.html\n", + "\n", + "Our implementation of `TimeMachineDataset` will be a concise replacement of the `SeqDataLoader` class previously created. Using a dataset in DJL format, will allow us to use already built-in functions so we don't have to implement most things from scratch. We have to implement a Builder, a prepare function which will contain the process to save the data to the TimeMachineDataset object, and finally a get function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "public class TimeMachineDataset extends RandomAccessDataset {\n", + "\n", + " private Vocab vocab;\n", + " private NDArray data;\n", + " private NDArray labels;\n", + " private int numSteps;\n", + " private int maxTokens;\n", + " private int batchSize;\n", + " private NDManager manager;\n", + " private boolean prepared;\n", + "\n", + " public TimeMachineDataset(Builder builder) {\n", + " super(builder);\n", + " this.numSteps = builder.numSteps;\n", + " this.maxTokens = builder.maxTokens;\n", + " this.batchSize = builder.getSampler().getBatchSize();\n", + " this.manager = builder.manager;\n", + " this.data = this.manager.create(new Shape(0,35), DataType.INT32);\n", + " this.labels = this.manager.create(new Shape(0,35), DataType.INT32);\n", + " this.prepared = false;\n", + " }\n", + "\n", + " @Override\n", + " public Record get(NDManager manager, long index) throws IOException {\n", + " NDArray X = data.get(new NDIndex(\"{}\", index));\n", + " NDArray Y = labels.get(new NDIndex(\"{}\", index));\n", + " return new Record(new NDList(X), new NDList(Y));\n", + " }\n", + "\n", + " @Override\n", + " protected long availableSize() {\n", + " return data.getShape().get(0);\n", + " }\n", + "\n", + " @Override\n", + " public void prepare(Progress progress) throws IOException, TranslateException {\n", + " if (prepared) {\n", + " return;\n", + " }\n", + "\n", + " Pair, Vocab> corpusVocabPair = null;\n", + " try {\n", + " corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens);\n", + " } catch (Exception e) {\n", + " e.printStackTrace(); // Exception can be from unknown token type during tokenize() function.\n", + " }\n", + " List corpus = corpusVocabPair.getKey();\n", + " this.vocab = corpusVocabPair.getValue();\n", + "\n", + " // Start with a random offset (inclusive of `numSteps - 1`) to partition a\n", + " // sequence\n", + " int offset = new Random().nextInt(numSteps);\n", + " int numTokens = ((int) ((corpus.size() - offset - 1) / batchSize)) * batchSize;\n", + " NDArray Xs =\n", + " manager.create(\n", + " corpus.subList(offset, offset + numTokens).stream()\n", + " .mapToInt(Integer::intValue)\n", + " .toArray());\n", + " NDArray Ys =\n", + " manager.create(\n", + " corpus.subList(offset + 1, offset + 1 + numTokens).stream()\n", + " .mapToInt(Integer::intValue)\n", + " .toArray());\n", + " Xs = Xs.reshape(new Shape(batchSize, -1));\n", + " Ys = Ys.reshape(new Shape(batchSize, -1));\n", + " int numBatches = (int) Xs.getShape().get(1) / numSteps;\n", + "\n", + " for (int i = 0; i < numSteps * numBatches; i += numSteps) {\n", + " NDArray X = Xs.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n", + " NDArray Y = Ys.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n", + " // Temp variables to be able to detach NDArray which will be replaced\n", + " NDArray temp = this.data;\n", + " NDArray temp2 = this.data;\n", + " this.data = this.data.concat(X);\n", + " this.labels = this.labels.concat(Y);\n", + " temp.detach();\n", + " temp2.detach();\n", + " }\n", + " this.prepared = true;\n", + " }\n", + "\n", + " public Vocab getVocab() {\n", + " return this.vocab;\n", + " }\n", + "\n", + " public static final class Builder extends BaseBuilder {\n", + "\n", + " int numSteps;\n", + " int maxTokens;\n", + " NDManager manager;\n", + "\n", + "\n", + " @Override\n", + " protected Builder self() { return this; }\n", + "\n", + " public Builder setSteps(int steps) {\n", + " this.numSteps = steps;\n", + " return this;\n", + " }\n", + "\n", + " public Builder setMaxTokens(int maxTokens) {\n", + " this.maxTokens = maxTokens;\n", + " return this;\n", + " }\n", + "\n", + " public Builder setManager(NDManager manager) {\n", + " this.manager = manager;\n", + " return this;\n", + " }\n", + "\n", + " public TimeMachineDataset build() throws IOException, TranslateException {\n", + " TimeMachineDataset dataset = new TimeMachineDataset(this);\n", + " return dataset;\n", + " }\n", + " }\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we will leverage the dataset that we just created and assign the required parameters." + ] + }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "int batchSize = 32;\n", "int numSteps = 35;\n", "\n", - "Pair, Vocab> timeMachine =\n", - " SeqDataLoader.loadDataTimeMachine(batchSize, numSteps, false, 10000, manager);\n", - "ArrayList trainIter = timeMachine.getKey();\n", - "Vocab vocab = timeMachine.getValue();" + "TimeMachineDataset dataset = new TimeMachineDataset.Builder()\n", + " .setManager(manager).setMaxTokens(10000).setSampling(batchSize, false)\n", + " .setSteps(numSteps).build();\n", + "dataset.prepare();\n", + "Vocab vocab = dataset.getVocab();" ] }, { @@ -116,7 +264,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -151,18 +299,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1\n", - "(1, 32, 256)\n" - ] - } - ], + "outputs": [], "source": [ "public static NDList beginState(int batchSize, int numLayers, int numHiddens) {\n", " return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));\n", @@ -219,18 +358,9 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(35, 32, 256)\n", - "(1, 32, 256)\n" - ] - } - ], + "outputs": [], "source": [ "NDArray X = manager.randomUniform (0, 1,new Shape(numSteps, batchSize, vocab.length()));\n", "\n", @@ -256,7 +386,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -276,7 +406,6 @@ " this.addChildBlock(\"linear\", dense);\n", " }\n", "\n", - " \n", " @Override\n", " protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList params) {\n", " NDArray X = inputs.get(0).transpose().oneHot(this.vocabSize);\n", @@ -318,17 +447,9 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "time travellerddjjjjjdjj\n" - ] - } - ], + "outputs": [], "source": [ "Device device = Functions.tryGpu(0);\n", "RNNModel net = new RNNModel(rnnLayer, vocab.length());\n", @@ -348,76 +469,13 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: cpu().\n", - "[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.7.0 in 0.069 ms.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - "\n" - ], - "text/plain": [ - "tech.tablesaw.plotly.components.Figure@4ed507e1" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "perplexity: 5.6, 69994.2 tokens/sec on cpu()\n", - "time traveller ou there thoe int and thene ta the this thaw ol \n", - "travellertand mh the thme travel er oae sooeee te tou than \n" - ] - } - ], + "outputs": [], "source": [ "int numEpochs = 500;\n", "int lr = 1;\n", - "TimeMachine.trainCh8((Object) net, trainIter, vocab, lr, numEpochs, device, false, manager);" + "TimeMachine.trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);" ] }, { diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index ee9c40b9..50ab5bea 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -13,7 +13,9 @@ import ai.djl.training.initializer.*; import ai.djl.training.evaluator.*; import ai.djl.training.optimizer.*; +import ai.djl.training.dataset.*; import ai.djl.training.listener.*; +import ai.djl.translate.TranslateException; import java.io.*; import java.net.URL; @@ -442,13 +444,14 @@ public static String predictCh8( /** Train a model. */ public static void trainCh8( Object net, - List trainIter, + RandomAccessDataset dataset, Vocab vocab, int lr, int numEpochs, Device device, boolean useRandomIter, - NDManager manager) { + NDManager manager) + throws IOException, TranslateException { SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss(); Animator animator = new Animator(); @@ -464,7 +467,7 @@ public static void trainCh8( Model model = Model.newInstance("model"); model.setBlock(castedNet); - Tracker lrt = Tracker.fixed(0.1f); + Tracker lrt = Tracker.fixed(lr); Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build(); DefaultTrainingConfig config = @@ -487,18 +490,14 @@ public static void trainCh8( double ppl = 0.0; double speed = 0.0; for (int epoch = 0; epoch < numEpochs; epoch++) { - // System.out.println("Epoch: " + epoch); Pair pair = - trainEpochCh8(net, trainIter, loss, updater, device, useRandomIter, manager); + trainEpochCh8(net, dataset, loss, updater, device, useRandomIter, manager); ppl = pair.getKey(); speed = pair.getValue(); if ((epoch + 1) % 10 == 0) { animator.add(epoch + 1, (float) ppl, "ppl"); animator.show(); } - // System.out.format( - // "perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, - // device.toString()); } System.out.format( "perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, device.toString()); @@ -509,22 +508,23 @@ public static void trainCh8( /** Train a model within one epoch. */ public static Pair trainEpochCh8( Object net, - List trainIter, + RandomAccessDataset dataset, Loss loss, Functions.voidTwoFunction updater, Device device, boolean useRandomIter, - NDManager manager) { + NDManager manager) + throws IOException, TranslateException { StopWatch watch = new StopWatch(); watch.start(); Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens try (NDManager childManager = manager.newSubManager()) { NDArray state = null; - for (NDList pair : trainIter) { - NDArray X = pair.get(0).toDevice(Functions.tryGpu(0), true); + for (Batch batch : dataset.getData(manager)) { + NDArray X = batch.getData().head().toDevice(Functions.tryGpu(0), true); X.attach(childManager); - NDArray Y = pair.get(1).toDevice(Functions.tryGpu(0), true); + NDArray Y = batch.getLabels().head().toDevice(Functions.tryGpu(0), true); Y.attach(childManager); if (state == null || useRandomIter) { // Initialize `state` when either it is the first iteration or @@ -573,7 +573,6 @@ public static Pair trainEpochCh8( } NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean(); - // System.out.println("Loss: " + l.getFloat()); gc.backward(l); metric.add(new float[] {l.getFloat() * y.size(), y.size()}); } @@ -610,3 +609,120 @@ public static void gradClipping(Object net, int theta, NDManager manager) { } } } + +public class TimeMachineDataset extends RandomAccessDataset { + + private Vocab vocab; + private NDArray data; + private NDArray labels; + private int numSteps; + private int maxTokens; + private int batchSize; + private NDManager manager; + private boolean prepared; + + public TimeMachineDataset(Builder builder) { + super(builder); + this.numSteps = builder.numSteps; + this.maxTokens = builder.maxTokens; + this.batchSize = builder.getSampler().getBatchSize(); + this.manager = builder.manager; + this.data = this.manager.create(new Shape(0,35), DataType.INT32); + this.labels = this.manager.create(new Shape(0,35), DataType.INT32); + this.prepared = false; + } + + @Override + public Record get(NDManager manager, long index) throws IOException { + NDArray X = data.get(new NDIndex("{}", index)); + NDArray Y = labels.get(new NDIndex("{}", index)); + return new Record(new NDList(X), new NDList(Y)); + } + + @Override + protected long availableSize() { + return data.getShape().get(0); + } + + @Override + public void prepare(Progress progress) throws IOException, TranslateException { + if (prepared) { + return; + } + + Pair, Vocab> corpusVocabPair = null; + try { + corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens); + } catch (Exception e) { + e.printStackTrace(); // Exception can be from unknown token type during tokenize() function. + } + List corpus = corpusVocabPair.getKey(); + this.vocab = corpusVocabPair.getValue(); + + // Start with a random offset (inclusive of `numSteps - 1`) to partition a + // sequence + int offset = new Random().nextInt(numSteps); + int numTokens = ((int) ((corpus.size() - offset - 1) / batchSize)) * batchSize; + NDArray Xs = + manager.create( + corpus.subList(offset, offset + numTokens).stream() + .mapToInt(Integer::intValue) + .toArray()); + NDArray Ys = + manager.create( + corpus.subList(offset + 1, offset + 1 + numTokens).stream() + .mapToInt(Integer::intValue) + .toArray()); + Xs = Xs.reshape(new Shape(batchSize, -1)); + Ys = Ys.reshape(new Shape(batchSize, -1)); + int numBatches = (int) Xs.getShape().get(1) / numSteps; + + for (int i = 0; i < numSteps * numBatches; i += numSteps) { + NDArray X = Xs.get(new NDIndex(":, {}:{}", i, i + numSteps)); + NDArray Y = Ys.get(new NDIndex(":, {}:{}", i, i + numSteps)); + // Temp variables to be able to detach NDArray which will be replaced + NDArray temp = this.data; + NDArray temp2 = this.data; + this.data = this.data.concat(X); + this.labels = this.labels.concat(Y); + temp.detach(); + temp2.detach(); + } + this.prepared = true; + } + + public Vocab getVocab() { + return this.vocab; + } + + public static final class Builder extends BaseBuilder { + + int numSteps; + int maxTokens; + NDManager manager; + + + @Override + protected Builder self() { return this; } + + public Builder setSteps(int steps) { + this.numSteps = steps; + return this; + } + + public Builder setMaxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder setManager(NDManager manager) { + this.manager = manager; + return this; + } + + public TimeMachineDataset build() throws IOException, TranslateException { + TimeMachineDataset dataset = new TimeMachineDataset(this); + return dataset; + } + } +} From dfb74496358e37e0d6de7b419beefb35ed325023 Mon Sep 17 00:00:00 2001 From: markbookk Date: Mon, 15 Mar 2021 12:32:36 -0400 Subject: [PATCH 14/20] Changing HashMap to LinkedHashMap to keep order of dataset --- utils/TimeMachineUtils.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index 50ab5bea..cb89ddff 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -30,7 +30,7 @@ public class Vocab { public Vocab(String[][] tokens, int minFreq, String[] reservedTokens) { // Sort according to frequencies - HashMap counter = countCorpus2D(tokens); + LinkedHashMap counter = countCorpus2D(tokens); this.tokenFreqs = new ArrayList>(counter.entrySet()); Collections.sort( tokenFreqs, @@ -77,9 +77,9 @@ public Integer getIdx(String token) { } /** Count token frequencies. */ - public HashMap countCorpus(String[] tokens) { + public LinkedHashMap countCorpus(String[] tokens) { - HashMap counter = new HashMap<>(); + LinkedHashMap counter = new LinkedHashMap<>(); if (tokens.length != 0) { for (String token : tokens) { counter.put(token, counter.getOrDefault(token, 0) + 1); @@ -89,7 +89,7 @@ public HashMap countCorpus(String[] tokens) { } /** Flatten a list of token lists into a list of tokens */ - public HashMap countCorpus2D(String[][] tokens) { + public LinkedHashMap countCorpus2D(String[][] tokens) { List allTokens = new ArrayList(); for (int i = 0; i < tokens.length; i++) { for (int j = 0; j < tokens[i].length; j++) { From 0c470ad74aca9caaff2af5de67fd719bf1992a0c Mon Sep 17 00:00:00 2001 From: markbookk Date: Mon, 15 Mar 2021 16:33:19 -0400 Subject: [PATCH 15/20] Adding suggestions on previous section's code review and adding implementation to notebook --- .../rnn-concise.ipynb | 287 +++++++++++++++++- utils/TimeMachineUtils.java | 16 +- 2 files changed, 293 insertions(+), 10 deletions(-) diff --git a/chapter_recurrent-neural-networks/rnn-concise.ipynb b/chapter_recurrent-neural-networks/rnn-concise.ipynb index 1fbae4f3..a6a52587 100644 --- a/chapter_recurrent-neural-networks/rnn-concise.ipynb +++ b/chapter_recurrent-neural-networks/rnn-concise.ipynb @@ -225,6 +225,289 @@ "}" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Consequently we will update our code from the previous section for the functions `predictCh8`, `trainCh8`, `trainEpochCh8`, and `gradClipping` to include the dataset logic and also allow the functions to accept an `AbstractBlock` from DJL instead of just accepting `RNNModelScratch`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "/** Generate new characters following the `prefix`. */\n", + "public static String predictCh8(\n", + " String prefix,\n", + " int numPreds,\n", + " Object net,\n", + " Vocab vocab,\n", + " Device device,\n", + " NDManager manager) {\n", + "\n", + " List outputs = new ArrayList<>();\n", + " outputs.add(vocab.getIdx(\"\" + prefix.charAt(0)));\n", + " Functions.SimpleFunction getInput =\n", + " () ->\n", + " manager.create(outputs.get(outputs.size() - 1))\n", + " .toDevice(device, false)\n", + " .reshape(new Shape(1, 1));\n", + "\n", + " if (net instanceof RNNModelScratch) {\n", + " RNNModelScratch castedNet = (RNNModelScratch) net;\n", + " NDArray state = castedNet.beginState(1, device);\n", + "\n", + " for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n", + " state = (NDArray) castedNet.forward(getInput.apply(), state).getValue();\n", + " outputs.add(vocab.getIdx(\"\" + c));\n", + " }\n", + "\n", + " NDArray y;\n", + " for (int i = 0; i < numPreds; i++) {\n", + " Pair pair = castedNet.forward(getInput.apply(), state);\n", + " y = pair.getKey();\n", + " state = pair.getValue();\n", + "\n", + " outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));\n", + " }\n", + " } else {\n", + " AbstractBlock castedNet = (AbstractBlock) net;\n", + " NDArray state = null;\n", + " for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n", + " if (state == null) {\n", + " // Begin state\n", + " state =\n", + " castedNet\n", + " .forward(\n", + " new ParameterStore(manager, false),\n", + " new NDList(getInput.apply()),\n", + " false)\n", + " .get(1);\n", + " } else {\n", + " state =\n", + " castedNet\n", + " .forward(\n", + " new ParameterStore(manager, false),\n", + " new NDList(getInput.apply(), state),\n", + " false)\n", + " .get(1);\n", + " }\n", + " outputs.add(vocab.getIdx(\"\" + c));\n", + " }\n", + "\n", + " NDArray y;\n", + " for (int i = 0; i < numPreds; i++) {\n", + " NDList pair =\n", + " castedNet.forward(\n", + " new ParameterStore(manager, false),\n", + " new NDList(getInput.apply(), state),\n", + " false);\n", + " y = pair.get(0);\n", + " state = pair.get(1);\n", + "\n", + " outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));\n", + " }\n", + " }\n", + "\n", + " StringBuilder output = new StringBuilder();\n", + " for (int i : outputs) {\n", + " output.append(vocab.idxToToken.get(i));\n", + " }\n", + " return output.toString();\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "/** Train a model. */\n", + "public static void trainCh8(\n", + " Object net,\n", + " RandomAccessDataset dataset,\n", + " Vocab vocab,\n", + " int lr,\n", + " int numEpochs,\n", + " Device device,\n", + " boolean useRandomIter,\n", + " NDManager manager)\n", + " throws IOException, TranslateException {\n", + " SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();\n", + " Animator animator = new Animator();\n", + "\n", + " Functions.voidTwoFunction updater;\n", + " if (net instanceof RNNModelScratch) {\n", + " RNNModelScratch castedNet = (RNNModelScratch) net;\n", + " updater =\n", + " (batchSize, subManager) ->\n", + " Training.sgd(castedNet.params, lr, batchSize, subManager);\n", + " } else {\n", + " // Already initialized net\n", + " AbstractBlock castedNet = (AbstractBlock) net;\n", + " Model model = Model.newInstance(\"model\");\n", + " model.setBlock(castedNet);\n", + "\n", + " Tracker lrt = Tracker.fixed(lr);\n", + " Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n", + "\n", + " DefaultTrainingConfig config =\n", + " new DefaultTrainingConfig(loss)\n", + " .optOptimizer(sgd) // Optimizer (loss function)\n", + " .optInitializer(\n", + " new NormalInitializer(0.01f),\n", + " Parameter.Type.WEIGHT) // setting the initializer\n", + " .optDevices(Device.getDevices(1)) // setting the number of GPUs needed\n", + " .addEvaluator(new Accuracy()) // Model Accuracy\n", + " .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n", + "\n", + " Trainer trainer = model.newTrainer(config);\n", + " updater = (batchSize, subManager) -> trainer.step();\n", + " }\n", + "\n", + " Function predict =\n", + " (prefix) -> predictCh8(prefix, 50, net, vocab, device, manager);\n", + " // Train and predict\n", + " double ppl = 0.0;\n", + " double speed = 0.0;\n", + " for (int epoch = 0; epoch < numEpochs; epoch++) {\n", + " Pair pair =\n", + " trainEpochCh8(net, dataset, loss, updater, device, useRandomIter, manager);\n", + " ppl = pair.getKey();\n", + " speed = pair.getValue();\n", + " if ((epoch + 1) % 10 == 0) {\n", + " animator.add(epoch + 1, (float) ppl, \"ppl\");\n", + " animator.show();\n", + " }\n", + " }\n", + " System.out.format(\n", + " \"perplexity: %.1f, %.1f tokens/sec on %s%n\", ppl, speed, device.toString());\n", + " System.out.println(predict.apply(\"time traveller\"));\n", + " System.out.println(predict.apply(\"traveller\"));\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "/** Train a model within one epoch. */\n", + "public static Pair trainEpochCh8(\n", + " Object net,\n", + " RandomAccessDataset dataset,\n", + " Loss loss,\n", + " Functions.voidTwoFunction updater,\n", + " Device device,\n", + " boolean useRandomIter,\n", + " NDManager manager)\n", + " throws IOException, TranslateException {\n", + " StopWatch watch = new StopWatch();\n", + " watch.start();\n", + " Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens\n", + "\n", + " try (NDManager childManager = manager.newSubManager()) {\n", + " NDArray state = null;\n", + " for (Batch batch : dataset.getData(manager)) {\n", + " NDArray X = batch.getData().head().toDevice(Functions.tryGpu(0), true);\n", + " X.attach(childManager);\n", + " NDArray Y = batch.getLabels().head().toDevice(Functions.tryGpu(0), true);\n", + " Y.attach(childManager);\n", + " if (state == null || useRandomIter) {\n", + " // Initialize `state` when either it is the first iteration or\n", + " // using random sampling\n", + " if (net instanceof RNNModelScratch) {\n", + " state =\n", + " ((RNNModelScratch) net)\n", + " .beginState((int) X.getShape().getShape()[0], device);\n", + " }\n", + " } else {\n", + " state.stopGradient();\n", + " }\n", + " if (state != null) {\n", + " state.attach(childManager);\n", + " }\n", + "\n", + " NDArray y = Y.transpose().reshape(new Shape(-1));\n", + " X = X.toDevice(device, false);\n", + " y = y.toDevice(device, false);\n", + " try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n", + " NDArray yHat;\n", + " if (net instanceof RNNModelScratch) {\n", + " Pair pairResult = ((RNNModelScratch) net).forward(X, state);\n", + " yHat = pairResult.getKey();\n", + " state = pairResult.getValue();\n", + " } else {\n", + " NDList pairResult;\n", + " if (state == null) {\n", + " // Begin state\n", + " pairResult =\n", + " ((AbstractBlock) net)\n", + " .forward(\n", + " new ParameterStore(manager, false),\n", + " new NDList(X),\n", + " true);\n", + " } else {\n", + " pairResult =\n", + " ((AbstractBlock) net)\n", + " .forward(\n", + " new ParameterStore(manager, false),\n", + " new NDList(X, state),\n", + " true);\n", + " }\n", + " yHat = pairResult.get(0);\n", + " state = pairResult.get(1);\n", + " }\n", + "\n", + " NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();\n", + " gc.backward(l);\n", + " metric.add(new float[] {l.getFloat() * y.size(), y.size()});\n", + " }\n", + " gradClipping(net, 1, childManager);\n", + " updater.apply(1, childManager); // Since the `mean` function has been invoked\n", + " }\n", + " }\n", + " return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "/** Clip the gradient. */\n", + "public static void gradClipping(Object net, int theta, NDManager manager) {\n", + " double result = 0;\n", + " NDList params;\n", + " if (net instanceof RNNModelScratch) {\n", + " params = ((RNNModelScratch) net).params;\n", + " } else {\n", + " params = new NDList();\n", + " for (Pair pair : ((AbstractBlock) net).getParameters()) {\n", + " params.add(pair.getValue().getArray());\n", + " }\n", + " }\n", + " for (NDArray p : params) {\n", + " NDArray gradient = p.getGradient().stopGradient();\n", + " gradient.attach(manager);\n", + " result += gradient.pow(2).sum().getFloat();\n", + " }\n", + " double norm = Math.sqrt(result);\n", + " if (norm > theta) {\n", + " for (NDArray param : params) {\n", + " NDArray gradient = param.getGradient().stopGradient();\n", + " param.getGradient().set(new NDIndex(\":\"), gradient.mul(theta / norm));\n", + " }\n", + " }\n", + "}" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -454,7 +737,7 @@ "Device device = Functions.tryGpu(0);\n", "RNNModel net = new RNNModel(rnnLayer, vocab.length());\n", "net.initialize(manager, DataType.FLOAT32, X.getShape());\n", - "String prediction = TimeMachine.predictCh8(\"time traveller\", 10, net, vocab, device, manager);\n", + "String prediction = predictCh8(\"time traveller\", 10, net, vocab, device, manager);\n", "System.out.println(prediction);" ] }, @@ -475,7 +758,7 @@ "source": [ "int numEpochs = 500;\n", "int lr = 1;\n", - "TimeMachine.trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);" + "trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);" ] }, { diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index cb89ddff..7a5f3117 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -253,7 +253,7 @@ public RNNModelScratch( this.forwardFn = forwardFn; } - public Pair call(NDArray X, NDArray state) { + public Pair forward(NDArray X, NDArray state) { X = X.transpose().oneHot(this.vocabSize); return this.forwardFn.apply(X, state, this.params); } @@ -383,13 +383,13 @@ public static String predictCh8( NDArray state = castedNet.beginState(1, device); for (char c : prefix.substring(1).toCharArray()) { // Warm-up period - state = (NDArray) castedNet.call(getInput.apply(), state).getValue(); + state = (NDArray) castedNet.forward(getInput.apply(), state).getValue(); outputs.add(vocab.getIdx("" + c)); } NDArray y; for (int i = 0; i < numPreds; i++) { - Pair pair = castedNet.call(getInput.apply(), state); + Pair pair = castedNet.forward(getInput.apply(), state); y = pair.getKey(); state = pair.getValue(); @@ -434,11 +434,11 @@ public static String predictCh8( } } - String outputString = ""; + StringBuilder output = new StringBuilder(); for (int i : outputs) { - outputString += vocab.idxToToken.get(i); + output.append(vocab.idxToToken.get(i)); } - return outputString; + return output.toString(); } /** Train a model. */ @@ -547,7 +547,7 @@ public static Pair trainEpochCh8( try (GradientCollector gc = Engine.getInstance().newGradientCollector()) { NDArray yHat; if (net instanceof RNNModelScratch) { - Pair pairResult = ((RNNModelScratch) net).call(X, state); + Pair pairResult = ((RNNModelScratch) net).forward(X, state); yHat = pairResult.getKey(); state = pairResult.getValue(); } else { @@ -580,7 +580,7 @@ public static Pair trainEpochCh8( updater.apply(1, childManager); // Since the `mean` function has been invoked } } - return new Pair(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop()); + return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop()); } /** Clip the gradient. */ From 72ef66bd5b7bfa0d3d5c19f1498a5722242ed667 Mon Sep 17 00:00:00 2001 From: markbookk Date: Tue, 16 Mar 2021 17:19:01 -0400 Subject: [PATCH 16/20] Changing state from NDArray to NDList --- .../rnn-concise.ipynb | 34 +++++++------ utils/TimeMachineUtils.java | 51 ++++++++++--------- 2 files changed, 46 insertions(+), 39 deletions(-) diff --git a/chapter_recurrent-neural-networks/rnn-concise.ipynb b/chapter_recurrent-neural-networks/rnn-concise.ipynb index a6a52587..29110ef2 100644 --- a/chapter_recurrent-neural-networks/rnn-concise.ipynb +++ b/chapter_recurrent-neural-networks/rnn-concise.ipynb @@ -257,16 +257,16 @@ "\n", " if (net instanceof RNNModelScratch) {\n", " RNNModelScratch castedNet = (RNNModelScratch) net;\n", - " NDArray state = castedNet.beginState(1, device);\n", + " NDList state = castedNet.beginState(1, device);\n", "\n", " for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n", - " state = (NDArray) castedNet.forward(getInput.apply(), state).getValue();\n", + " state = (NDList) castedNet.forward(getInput.apply(), state).getValue();\n", " outputs.add(vocab.getIdx(\"\" + c));\n", " }\n", "\n", " NDArray y;\n", " for (int i = 0; i < numPreds; i++) {\n", - " Pair pair = castedNet.forward(getInput.apply(), state);\n", + " Pair pair = castedNet.forward(getInput.apply(), state);\n", " y = pair.getKey();\n", " state = pair.getValue();\n", "\n", @@ -274,7 +274,7 @@ " }\n", " } else {\n", " AbstractBlock castedNet = (AbstractBlock) net;\n", - " NDArray state = null;\n", + " NDList state = null;\n", " for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n", " if (state == null) {\n", " // Begin state\n", @@ -284,15 +284,15 @@ " new ParameterStore(manager, false),\n", " new NDList(getInput.apply()),\n", " false)\n", - " .get(1);\n", + " .subNDList(1);\n", " } else {\n", " state =\n", " castedNet\n", " .forward(\n", " new ParameterStore(manager, false),\n", - " new NDList(getInput.apply(), state),\n", + " new NDList(getInput.apply()).addAll(state),\n", " false)\n", - " .get(1);\n", + " .subNDList(1);\n", " }\n", " outputs.add(vocab.getIdx(\"\" + c));\n", " }\n", @@ -302,10 +302,10 @@ " NDList pair =\n", " castedNet.forward(\n", " new ParameterStore(manager, false),\n", - " new NDList(getInput.apply(), state),\n", + " new NDList(getInput.apply()).addAll(state),\n", " false);\n", " y = pair.get(0);\n", - " state = pair.get(1);\n", + " state = pair.subNDList(1);\n", "\n", " outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));\n", " }\n", @@ -379,8 +379,8 @@ " ppl = pair.getKey();\n", " speed = pair.getValue();\n", " if ((epoch + 1) % 10 == 0) {\n", - " animator.add(epoch + 1, (float) ppl, \"ppl\");\n", - " animator.show();\n", + " animator.add(epoch + 1, (float) ppl, \"ppl\");\n", + " animator.show();\n", " }\n", " }\n", " System.out.format(\n", @@ -411,7 +411,7 @@ " Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens\n", "\n", " try (NDManager childManager = manager.newSubManager()) {\n", - " NDArray state = null;\n", + " NDList state = null;\n", " for (Batch batch : dataset.getData(manager)) {\n", " NDArray X = batch.getData().head().toDevice(Functions.tryGpu(0), true);\n", " X.attach(childManager);\n", @@ -426,7 +426,9 @@ " .beginState((int) X.getShape().getShape()[0], device);\n", " }\n", " } else {\n", - " state.stopGradient();\n", + " for (NDArray s : state) {\n", + " s.stopGradient();\n", + " }\n", " }\n", " if (state != null) {\n", " state.attach(childManager);\n", @@ -438,7 +440,7 @@ " try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n", " NDArray yHat;\n", " if (net instanceof RNNModelScratch) {\n", - " Pair pairResult = ((RNNModelScratch) net).forward(X, state);\n", + " Pair pairResult = ((RNNModelScratch) net).forward(X, state);\n", " yHat = pairResult.getKey();\n", " state = pairResult.getValue();\n", " } else {\n", @@ -456,11 +458,11 @@ " ((AbstractBlock) net)\n", " .forward(\n", " new ParameterStore(manager, false),\n", - " new NDList(X, state),\n", + " new NDList(X).addAll(state),\n", " true);\n", " }\n", " yHat = pairResult.get(0);\n", - " state = pairResult.get(1);\n", + " state = pairResult.subNDList(1);\n", " }\n", "\n", " NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();\n", diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index 7a5f3117..b8f90f00 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -236,16 +236,16 @@ public class RNNModelScratch { public int vocabSize; public int numHiddens; public NDList params; - public Functions.TriFunction initState; - public Functions.TriFunction forwardFn; + public Functions.TriFunction initState; + public Functions.TriFunction forwardFn; public RNNModelScratch( int vocabSize, int numHiddens, Device device, Functions.TriFunction getParams, - Functions.TriFunction initRNNState, - Functions.TriFunction forwardFn) { + Functions.TriFunction initRNNState, + Functions.TriFunction forwardFn) { this.vocabSize = vocabSize; this.numHiddens = numHiddens; this.params = getParams.apply(vocabSize, numHiddens, device); @@ -253,12 +253,12 @@ public RNNModelScratch( this.forwardFn = forwardFn; } - public Pair forward(NDArray X, NDArray state) { + public Pair forward(NDArray X, NDList state) { X = X.transpose().oneHot(this.vocabSize); return this.forwardFn.apply(X, state, this.params); } - public NDArray beginState(int batchSize, Device device) { + public NDList beginState(int batchSize, Device device) { return this.initState.apply(batchSize, this.numHiddens, device); } } @@ -380,16 +380,16 @@ public static String predictCh8( if (net instanceof RNNModelScratch) { RNNModelScratch castedNet = (RNNModelScratch) net; - NDArray state = castedNet.beginState(1, device); + NDList state = castedNet.beginState(1, device); for (char c : prefix.substring(1).toCharArray()) { // Warm-up period - state = (NDArray) castedNet.forward(getInput.apply(), state).getValue(); + state = (NDList) castedNet.forward(getInput.apply(), state).getValue(); outputs.add(vocab.getIdx("" + c)); } NDArray y; for (int i = 0; i < numPreds; i++) { - Pair pair = castedNet.forward(getInput.apply(), state); + Pair pair = castedNet.forward(getInput.apply(), state); y = pair.getKey(); state = pair.getValue(); @@ -397,7 +397,7 @@ public static String predictCh8( } } else { AbstractBlock castedNet = (AbstractBlock) net; - NDArray state = null; + NDList state = null; for (char c : prefix.substring(1).toCharArray()) { // Warm-up period if (state == null) { // Begin state @@ -407,15 +407,15 @@ public static String predictCh8( new ParameterStore(manager, false), new NDList(getInput.apply()), false) - .get(1); + .subNDList(1); } else { state = castedNet .forward( new ParameterStore(manager, false), - new NDList(getInput.apply(), state), + new NDList(getInput.apply()).addAll(state), false) - .get(1); + .subNDList(1); } outputs.add(vocab.getIdx("" + c)); } @@ -425,10 +425,10 @@ public static String predictCh8( NDList pair = castedNet.forward( new ParameterStore(manager, false), - new NDList(getInput.apply(), state), + new NDList(getInput.apply()).addAll(state), false); y = pair.get(0); - state = pair.get(1); + state = pair.subNDList(1); outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L)); } @@ -453,7 +453,7 @@ public static void trainCh8( NDManager manager) throws IOException, TranslateException { SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss(); - Animator animator = new Animator(); + // Animator animator = new Animator(); Functions.voidTwoFunction updater; if (net instanceof RNNModelScratch) { @@ -495,9 +495,12 @@ public static void trainCh8( ppl = pair.getKey(); speed = pair.getValue(); if ((epoch + 1) % 10 == 0) { - animator.add(epoch + 1, (float) ppl, "ppl"); - animator.show(); + // animator.add(epoch + 1, (float) ppl, "ppl"); + // animator.show(); } + System.out.format( + "epoch: %d, perplexity: %.1f, %.1f tokens/sec on %s%n", + epoch, ppl, speed, device.toString()); } System.out.format( "perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, device.toString()); @@ -520,7 +523,7 @@ public static Pair trainEpochCh8( Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens try (NDManager childManager = manager.newSubManager()) { - NDArray state = null; + NDList state = null; for (Batch batch : dataset.getData(manager)) { NDArray X = batch.getData().head().toDevice(Functions.tryGpu(0), true); X.attach(childManager); @@ -535,7 +538,9 @@ public static Pair trainEpochCh8( .beginState((int) X.getShape().getShape()[0], device); } } else { - state.stopGradient(); + for (NDArray s : state) { + s.stopGradient(); + } } if (state != null) { state.attach(childManager); @@ -547,7 +552,7 @@ public static Pair trainEpochCh8( try (GradientCollector gc = Engine.getInstance().newGradientCollector()) { NDArray yHat; if (net instanceof RNNModelScratch) { - Pair pairResult = ((RNNModelScratch) net).forward(X, state); + Pair pairResult = ((RNNModelScratch) net).forward(X, state); yHat = pairResult.getKey(); state = pairResult.getValue(); } else { @@ -565,11 +570,11 @@ public static Pair trainEpochCh8( ((AbstractBlock) net) .forward( new ParameterStore(manager, false), - new NDList(X, state), + new NDList(X).addAll(state), true); } yHat = pairResult.get(0); - state = pairResult.get(1); + state = pairResult.subNDList(1); } NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean(); From 418645dfc21ac1be7dafb0c07348612a4214c1f0 Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 17 Mar 2021 11:53:37 -0400 Subject: [PATCH 17/20] Changing type of RNNModel to generic to be able to be used with GRU and others --- utils/TimeMachineUtils.java | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index b8f90f00..bb98418a 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -263,19 +263,20 @@ public NDList beginState(int batchSize, Device device) { } } -public class RNNModel extends AbstractBlock { +public class RNNModel extends AbstractBlock { private static final byte VERSION = 2; - private RNN rnnLayer; + private T rnnLayer; private Linear dense; private int vocabSize; - public RNNModel(RNN rnnLayer, int vocabSize) { + public RNNModel (T rnnLayer, int vocabSize) { super(VERSION); this.rnnLayer = rnnLayer; this.addChildBlock("rnn", rnnLayer); this.vocabSize = vocabSize; this.dense = Linear.builder().setUnits(vocabSize).build(); + this.addChildBlock("linear", dense); } /** {@inheritDoc} */ @@ -453,7 +454,7 @@ public static void trainCh8( NDManager manager) throws IOException, TranslateException { SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss(); - // Animator animator = new Animator(); + Animator animator = new Animator(); Functions.voidTwoFunction updater; if (net instanceof RNNModelScratch) { @@ -495,12 +496,9 @@ public static void trainCh8( ppl = pair.getKey(); speed = pair.getValue(); if ((epoch + 1) % 10 == 0) { - // animator.add(epoch + 1, (float) ppl, "ppl"); - // animator.show(); + animator.add(epoch + 1, (float) ppl, "ppl"); + animator.show(); } - System.out.format( - "epoch: %d, perplexity: %.1f, %.1f tokens/sec on %s%n", - epoch, ppl, speed, device.toString()); } System.out.format( "perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, device.toString()); From 888e4f6433d492c42456f3eadd9e4358dccd9e68 Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 24 Mar 2021 10:28:19 -0400 Subject: [PATCH 18/20] Adding code review suggestions; changing mul to muli in gradientClipping; using NDList to close list of NDArrays --- .../rnn-concise.ipynb | 27 +++++++++---------- utils/TimeMachineUtils.java | 4 +-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/chapter_recurrent-neural-networks/rnn-concise.ipynb b/chapter_recurrent-neural-networks/rnn-concise.ipynb index 29110ef2..0ce24a47 100644 --- a/chapter_recurrent-neural-networks/rnn-concise.ipynb +++ b/chapter_recurrent-neural-networks/rnn-concise.ipynb @@ -174,17 +174,18 @@ " Ys = Ys.reshape(new Shape(batchSize, -1));\n", " int numBatches = (int) Xs.getShape().get(1) / numSteps;\n", "\n", + " NDList xNDList = new NDList();\n", + " NDList yNDList = new NDList();\n", " for (int i = 0; i < numSteps * numBatches; i += numSteps) {\n", " NDArray X = Xs.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n", " NDArray Y = Ys.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n", - " // Temp variables to be able to detach NDArray which will be replaced\n", - " NDArray temp = this.data;\n", - " NDArray temp2 = this.data;\n", - " this.data = this.data.concat(X);\n", - " this.labels = this.labels.concat(Y);\n", - " temp.detach();\n", - " temp2.detach();\n", + " xNDList.add(X);\n", + " yNDList.add(Y);\n", " }\n", + " this.data = NDArrays.concat(xNDList);\n", + " xNDList.close();\n", + " this.labels = NDArrays.concat(yNDList);\n", + " yNDList.close();\n", " this.prepared = true;\n", " }\n", "\n", @@ -193,12 +194,10 @@ " }\n", "\n", " public static final class Builder extends BaseBuilder {\n", - "\n", " int numSteps;\n", " int maxTokens;\n", " NDManager manager;\n", "\n", - "\n", " @Override\n", " protected Builder self() { return this; }\n", "\n", @@ -412,11 +411,9 @@ "\n", " try (NDManager childManager = manager.newSubManager()) {\n", " NDList state = null;\n", - " for (Batch batch : dataset.getData(manager)) {\n", + " for (Batch batch : dataset.getData(childManager)) {\n", " NDArray X = batch.getData().head().toDevice(Functions.tryGpu(0), true);\n", - " X.attach(childManager);\n", " NDArray Y = batch.getLabels().head().toDevice(Functions.tryGpu(0), true);\n", - " Y.attach(childManager);\n", " if (state == null || useRandomIter) {\n", " // Initialize `state` when either it is the first iteration or\n", " // using random sampling\n", @@ -503,8 +500,8 @@ " double norm = Math.sqrt(result);\n", " if (norm > theta) {\n", " for (NDArray param : params) {\n", - " NDArray gradient = param.getGradient().stopGradient();\n", - " param.getGradient().set(new NDIndex(\":\"), gradient.mul(theta / norm));\n", + " NDArray gradient = param.getGradient();\n", + " gradient.muli(theta / norm);\n", " }\n", " }\n", "}" @@ -699,7 +696,7 @@ " NDArray Y = result.get(0);\n", " NDArray state = result.get(1);\n", "\n", - " int shapeLength = Y.getShape().getShape().length;\n", + " int shapeLength = Y.getShape().dimension();\n", " NDList output = this.dense.forward(parameterStore, new NDList(Y\n", " .reshape(new Shape(-1, Y.getShape().get(shapeLength-1)))), training);\n", " return new NDList(output.get(0), state);\n", diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index bb98418a..a475fcd6 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -606,8 +606,8 @@ public static void gradClipping(Object net, int theta, NDManager manager) { double norm = Math.sqrt(result); if (norm > theta) { for (NDArray param : params) { - NDArray gradient = param.getGradient().stopGradient(); - param.getGradient().set(new NDIndex(":"), gradient.mul(theta / norm)); + NDArray gradient = param.getGradient(); + gradient.muli(theta / norm); } } } From 9a302f2741a23115a0df92247bfd168b4c349705 Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 24 Mar 2021 11:58:23 -0400 Subject: [PATCH 19/20] Upgrading version to 0.11.0-snapshot of section 8.3 and adding dependencies to TimeMachineUtils.java --- .../language-models-and-dataset.ipynb | 4 ++-- utils/TimeMachineUtils.java | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/chapter_recurrent-neural-networks/language-models-and-dataset.ipynb b/chapter_recurrent-neural-networks/language-models-and-dataset.ipynb index b55f141b..cce4849a 100644 --- a/chapter_recurrent-neural-networks/language-models-and-dataset.ipynb +++ b/chapter_recurrent-neural-networks/language-models-and-dataset.ipynb @@ -142,11 +142,11 @@ "source": [ "%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", "\n", - "%maven ai.djl:api:0.10.0\n", + "%maven ai.djl:api:0.11.0-SNAPSHOT\n", "%maven org.slf4j:slf4j-api:1.7.26\n", "%maven org.slf4j:slf4j-simple:1.7.26\n", "\n", - "%maven ai.djl.mxnet:mxnet-engine:0.10.0\n", + "%maven ai.djl.mxnet:mxnet-engine:0.11.0-SNAPSHOT\n", "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport" ] }, diff --git a/utils/TimeMachineUtils.java b/utils/TimeMachineUtils.java index a475fcd6..8b5254a5 100644 --- a/utils/TimeMachineUtils.java +++ b/utils/TimeMachineUtils.java @@ -22,6 +22,12 @@ import java.util.*; import java.util.function.*; +%load ../utils/Functions.java +%load ../utils/Animator.java +%load ../utils/Training.java +%load ../utils/StopWatch.java +%load ../utils/Accumulator.java + public class Vocab { public int unk; public List> tokenFreqs; From f85c8f6deffd7c4c7355ac9025bbaa634d11052c Mon Sep 17 00:00:00 2001 From: markbookk Date: Wed, 24 Mar 2021 12:34:52 -0400 Subject: [PATCH 20/20] Updating rnn-scratch notebook to include changes of section 8.6 --- chapter_recurrent-neural-networks/rnn-scratch.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chapter_recurrent-neural-networks/rnn-scratch.ipynb b/chapter_recurrent-neural-networks/rnn-scratch.ipynb index dbbc4b09..626be3b8 100644 --- a/chapter_recurrent-neural-networks/rnn-scratch.ipynb +++ b/chapter_recurrent-neural-networks/rnn-scratch.ipynb @@ -123,7 +123,7 @@ "source": [ "int batchSize = 32;\n", "int numSteps = 35;\n", - "Pair, Vocab> timeMachine = loadDataTimeMachine(batchSize, numSteps, false, 10000);\n", + "Pair, Vocab> timeMachine = SeqDataLoader.loadDataTimeMachine(batchSize, numSteps, false, 10000, manager);\n", "List trainIter = timeMachine.getKey();\n", "Vocab vocab = timeMachine.getValue();" ] @@ -226,7 +226,7 @@ " // Attach gradients\n", " NDList params = new NDList(W_xh, W_hh, b_h, W_hq, b_q);\n", " for (NDArray param : params) {\n", - " param.attachGradient();\n", + " param.setRequiresGradient(true);\n", " }\n", " return params;\n", "}\n",