Commit

Fixing bug (learning rate was not set appropriately), adding TimeMachineDataset, cleaning code, and adding documentation
markbookk committed Mar 15, 2021
1 parent 7591ef4 commit 616f5b1
Showing 2 changed files with 298 additions and 124 deletions.
278 changes: 168 additions & 110 deletions chapter_recurrent-neural-networks/rnn-concise.ipynb
@@ -19,7 +19,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -38,7 +38,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -55,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -78,26 +78,174 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a Dataset in DJL"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In DJL, the ideal and concise way of dealing with datasets is to use the built-in datasets, which can easily wrap around existing NDArrays, or to create your own dataset by extending the `RandomAccessDataset` class. In this section, we will implement our own. For more information on creating your own dataset in DJL, refer to: https://djl.ai/docs/development/how_to_use_dataset.html\n",
"\n",
"Our implementation of `TimeMachineDataset` is a concise replacement for the `SeqDataLoader` class created previously. Using a dataset in the DJL format allows us to reuse built-in functionality, so we don't have to implement most things from scratch. We have to implement a `Builder`, a `prepare()` function that saves the processed data to the `TimeMachineDataset` object, and finally a `get()` function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"public class TimeMachineDataset extends RandomAccessDataset {\n",
"\n",
" private Vocab vocab;\n",
" private NDArray data;\n",
" private NDArray labels;\n",
" private int numSteps;\n",
" private int maxTokens;\n",
" private int batchSize;\n",
" private NDManager manager;\n",
" private boolean prepared;\n",
"\n",
" public TimeMachineDataset(Builder builder) {\n",
" super(builder);\n",
" this.numSteps = builder.numSteps;\n",
" this.maxTokens = builder.maxTokens;\n",
" this.batchSize = builder.getSampler().getBatchSize();\n",
" this.manager = builder.manager;\n",
" this.data = this.manager.create(new Shape(0, numSteps), DataType.INT32);\n",
" this.labels = this.manager.create(new Shape(0, numSteps), DataType.INT32);\n",
" this.prepared = false;\n",
" }\n",
"\n",
" @Override\n",
" public Record get(NDManager manager, long index) throws IOException {\n",
" NDArray X = data.get(new NDIndex(\"{}\", index));\n",
" NDArray Y = labels.get(new NDIndex(\"{}\", index));\n",
" return new Record(new NDList(X), new NDList(Y));\n",
" }\n",
"\n",
" @Override\n",
" protected long availableSize() {\n",
" return data.getShape().get(0);\n",
" }\n",
"\n",
" @Override\n",
" public void prepare(Progress progress) throws IOException, TranslateException {\n",
" if (prepared) {\n",
" return;\n",
" }\n",
"\n",
" Pair<List<Integer>, Vocab> corpusVocabPair = null;\n",
" try {\n",
" corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens);\n",
" } catch (Exception e) {\n",
" e.printStackTrace(); // Exception can be from unknown token type during tokenize() function.\n",
" }\n",
" List<Integer> corpus = corpusVocabPair.getKey();\n",
" this.vocab = corpusVocabPair.getValue();\n",
"\n",
" // Start with a random offset (inclusive of `numSteps - 1`) to partition a\n",
" // sequence\n",
" int offset = new Random().nextInt(numSteps);\n",
" int numTokens = ((int) ((corpus.size() - offset - 1) / batchSize)) * batchSize;\n",
" NDArray Xs =\n",
" manager.create(\n",
" corpus.subList(offset, offset + numTokens).stream()\n",
" .mapToInt(Integer::intValue)\n",
" .toArray());\n",
" NDArray Ys =\n",
" manager.create(\n",
" corpus.subList(offset + 1, offset + 1 + numTokens).stream()\n",
" .mapToInt(Integer::intValue)\n",
" .toArray());\n",
" Xs = Xs.reshape(new Shape(batchSize, -1));\n",
" Ys = Ys.reshape(new Shape(batchSize, -1));\n",
" int numBatches = (int) Xs.getShape().get(1) / numSteps;\n",
"\n",
" for (int i = 0; i < numSteps * numBatches; i += numSteps) {\n",
" NDArray X = Xs.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n",
" NDArray Y = Ys.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n",
" // Temp variables so we can detach the NDArrays that are being replaced\n",
" NDArray temp = this.data;\n",
" NDArray temp2 = this.labels;\n",
" this.data = this.data.concat(X);\n",
" this.labels = this.labels.concat(Y);\n",
" temp.detach();\n",
" temp2.detach();\n",
" }\n",
" this.prepared = true;\n",
" }\n",
"\n",
" public Vocab getVocab() {\n",
" return this.vocab;\n",
" }\n",
"\n",
" public static final class Builder extends BaseBuilder<Builder> {\n",
"\n",
" int numSteps;\n",
" int maxTokens;\n",
" NDManager manager;\n",
"\n",
"\n",
" @Override\n",
" protected Builder self() { return this; }\n",
"\n",
" public Builder setSteps(int steps) {\n",
" this.numSteps = steps;\n",
" return this;\n",
" }\n",
"\n",
" public Builder setMaxTokens(int maxTokens) {\n",
" this.maxTokens = maxTokens;\n",
" return this;\n",
" }\n",
"\n",
" public Builder setManager(NDManager manager) {\n",
" this.manager = manager;\n",
" return this;\n",
" }\n",
"\n",
" public TimeMachineDataset build() throws IOException, TranslateException {\n",
" TimeMachineDataset dataset = new TimeMachineDataset(this);\n",
" return dataset;\n",
" }\n",
" }\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we will leverage the dataset that we just created and assign the required parameters."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"int batchSize = 32;\n",
"int numSteps = 35;\n",
"\n",
"Pair<ArrayList<NDList>, Vocab> timeMachine =\n",
" SeqDataLoader.loadDataTimeMachine(batchSize, numSteps, false, 10000, manager);\n",
"ArrayList<NDList> trainIter = timeMachine.getKey();\n",
"Vocab vocab = timeMachine.getValue();"
"TimeMachineDataset dataset = new TimeMachineDataset.Builder()\n",
" .setManager(manager).setMaxTokens(10000).setSampling(batchSize, false)\n",
" .setSteps(numSteps).build();\n",
"dataset.prepare();\n",
"Vocab vocab = dataset.getVocab();"
]
},
{
@@ -116,7 +264,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -151,18 +299,9 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"(1, 32, 256)\n"
]
}
],
"outputs": [],
"source": [
"public static NDList beginState(int batchSize, int numLayers, int numHiddens) {\n",
" return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));\n",
@@ -219,18 +358,9 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(35, 32, 256)\n",
"(1, 32, 256)\n"
]
}
],
"outputs": [],
"source": [
"NDArray X = manager.randomUniform (0, 1,new Shape(numSteps, batchSize, vocab.length()));\n",
"\n",
@@ -256,7 +386,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -276,7 +406,6 @@
" this.addChildBlock(\"linear\", dense);\n",
" }\n",
"\n",
" \n",
" @Override\n",
" protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {\n",
" NDArray X = inputs.get(0).transpose().oneHot(this.vocabSize);\n",
@@ -318,17 +447,9 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time travellerddjjjjjdjj\n"
]
}
],
"outputs": [],
"source": [
"Device device = Functions.tryGpu(0);\n",
"RNNModel net = new RNNModel(rnnLayer, vocab.length());\n",
@@ -348,76 +469,13 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [ (removed output: a stderr log, "Training on: cpu()" and "Load MXNet Engine Version 1.7.0", an embedded Plotly chart of training perplexity falling from about 17.2 at epoch 10 to about 5.6 at epoch 500, and the stdout result "perplexity: 5.6, 69994.2 tokens/sec on cpu()" with two sample predictions) ],
"outputs": [],
"source": [
"int numEpochs = 500;\n",
"int lr = 1;\n",
"TimeMachine.trainCh8((Object) net, trainIter, vocab, lr, numEpochs, device, false, manager);"
"TimeMachine.trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);"
]
},
{

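The random-offset partitioning implemented in `prepare()` above is easy to get wrong, so here is a self-contained, plain-Java sketch of the same index arithmetic (no DJL types; `PartitionSketch` and `partition` are hypothetical names used only for this example). Each input window `X` is paired with a label window `Y` shifted one token to the right:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Plain-Java sketch of the sequence-partitioning arithmetic in
// TimeMachineDataset.prepare(): a token corpus is cut at a random
// offset, reshaped into batchSize rows, and sliced into windows of
// numSteps tokens, with labels shifted one position to the right.
public class PartitionSketch {
    static List<int[][]> partition(int[] corpus, int batchSize, int numSteps, long seed) {
        // Random starting offset in [0, numSteps), as in prepare()
        int offset = new Random(seed).nextInt(numSteps);
        // Keep only as many tokens as fill whole rows of the batch
        int numTokens = ((corpus.length - offset - 1) / batchSize) * batchSize;
        int cols = numTokens / batchSize;  // tokens per batch row
        int numBatches = cols / numSteps;  // full windows per row
        List<int[][]> pairs = new ArrayList<>();  // each entry: {X, Y}, row-major
        for (int b = 0; b < numBatches; b++) {
            int[] x = new int[batchSize * numSteps];
            int[] y = new int[batchSize * numSteps];
            for (int r = 0; r < batchSize; r++) {
                for (int c = 0; c < numSteps; c++) {
                    int src = offset + r * cols + b * numSteps + c;
                    x[r * numSteps + c] = corpus[src];      // input token
                    y[r * numSteps + c] = corpus[src + 1];  // label = next token
                }
            }
            pairs.add(new int[][] {x, y});
        }
        return pairs;
    }
}
```

With an identity corpus (token `k` at position `k`), every label equals its input plus one, which makes the pairing easy to verify by hand.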