This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add NumpyIter #114

Closed
wants to merge 21 commits
4 changes: 3 additions & 1 deletion doc/python/index.md
@@ -18,6 +18,8 @@ Code Examples

Python API Documents
--------------------
* [High Level Model Training Related API](model.md)
* [NDArray API](ndarray.md)
* [Symbolic API](symbol.md)
* [Data Loading API](io.md)
* [KVStore API](kvstore.md)
117 changes: 117 additions & 0 deletions doc/python/model.md
@@ -0,0 +1,117 @@
MXNet Python Model API
======================
The model API in mxnet is not really an API.
It is a thin wrapper built on top of the [ndarray](ndarray.md) and [symbolic](symbol.md)
modules to make neural network training easy.

* [Train a Model](#train-a-model) introduces basic training.
* [Save the Model](#save-the-model)
* [Periodically Checkpoint](#periodically-checkpoint)
* [Initializer API Reference](#initializer-api-reference)
* [Evaluation Metric API Reference](#evaluation-metric-api-reference)
* [Optimizer API Reference](#optimizer-api-reference)

Train a Model
-------------
To train a model, you follow two steps: first configure the network using symbols,
then call ```mx.model.FeedForward.create``` to create the model for you.
The following example creates a two-layer neural network.

```python
# configure a two-layer neural network
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu')
fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
softmax = mx.symbol.Softmax(fc2, name='sm')
# create a model from the symbol
model = mx.model.FeedForward.create(
    softmax,
    X=data_set,
    num_round=num_round,
    learning_rate=0.01)
```

You can also use the scikit-learn style construct-then-fit pattern to create a model, as shown in the sketch below.
For more information, refer to the [Model API Reference](#model-api-reference).
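
A minimal sketch of that style, reusing the ```softmax``` symbol and the ```data_set``` placeholder from the example above:

```python
# construct the model first ...
model = mx.model.FeedForward(
    symbol=softmax,
    num_round=num_round,
    learning_rate=0.01)
# ... then fit it to the data
model.fit(X=data_set)
```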

Save the Model
--------------
It is important to save your work after the job is done.
To save the model, you can directly pickle it if you prefer the pythonic way.
We also provide save and load functions.

```python
# save a model to mymodel-symbol.json and mymodel-0100.params
prefix = 'mymodel'
iteration = 100
model.save(prefix, iteration)

# load model back
model_loaded = mx.model.FeedForward.load(prefix, iteration)
```
The advantage of these save and load functions is that they are language agnostic,
and you should be able to save and load directly to cloud storage such as S3 and HDFS.
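
If you do prefer the pickle route mentioned above, a minimal sketch (assuming ```model``` is the ```FeedForward``` object from the earlier example):

```python
import pickle

# pickle the whole model object the pythonic way
with open('mymodel.pkl', 'wb') as f:
    pickle.dump(model, f)

# load it back later
with open('mymodel.pkl', 'rb') as f:
    model_loaded = pickle.load(f)
```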

Periodically Checkpoint
-----------------------
It is also helpful to periodically checkpoint your model after each iteration.
To do so, simply add a checkpoint callback ```mx.model.do_checkpoint(prefix)``` to the create function.
The training process will automatically checkpoint to the specified place after
each iteration.

```python
prefix = 'models/chkpt'
model = mx.model.FeedForward.create(
    softmax,
    X=data_set,
    iter_end_callback=mx.model.do_checkpoint(prefix),
    ...)
```
You can load the model checkpoint later using ```FeedForward.load```.
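
For instance, with the ```prefix``` above:

```python
# load the parameters saved after iteration 3 (models/chkpt-0003.params)
model_loaded = mx.model.FeedForward.load(prefix, 3)
```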

Use Multiple Devices
--------------------
Simply set ```ctx``` to the list of devices you would like to train on.

```python
devices = [mx.gpu(i) for i in range(num_device)]
model = mx.model.FeedForward.create(
    softmax,
    X=dataset,
    ctx=devices,
    ...)
```
The training will be done in a data-parallel way on the GPUs you specified.

Initializer API Reference
-------------------------

```eval_rst
.. automodule:: mxnet.initializer
:members:
```
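
As a usage sketch, you can pass one of the initializers documented below when creating a model; the ```Uniform``` scale here is illustrative:

```python
# draw initial weights uniformly from [-0.07, 0.07] instead of the default
model = mx.model.FeedForward.create(
    softmax,
    X=data_set,
    num_round=num_round,
    learning_rate=0.01,
    initializer=mx.initializer.Uniform(0.07))
```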

Evaluation Metric API Reference
-------------------------------

```eval_rst
.. automodule:: mxnet.metric
:members:
```
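
A sketch of plugging a metric into training, assuming ```fit``` accepts a metric object through ```eval_metric```:

```python
# report accuracy on the held-out set after each iteration
model.fit(X=train_dataiter,
          eval_data=test_dataiter,
          eval_metric=mx.metric.Accuracy())
```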

Optimizer API Reference
-----------------------

```eval_rst
.. automodule:: mxnet.optimizer
:members:
```
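
A sketch of selecting the optimizer by name and passing its hyper-parameters through ```create```, mirroring the keywords used in the cifar10 example later in this diff:

```python
# plain SGD with momentum and weight decay
model = mx.model.FeedForward.create(
    softmax,
    X=data_set,
    num_round=num_round,
    optimizer='sgd',
    learning_rate=0.01,
    momentum=0.9,
    wd=0.00001)
```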

Model API Reference
-------------------

```eval_rst
.. automodule:: mxnet.model
:members:
```
5 changes: 3 additions & 2 deletions doc/python/symbol.md
@@ -3,8 +3,9 @@ MXNet Python Symbolic API
* [How to Compose Symbols](#overloaded-operators) introduces operator overloading of symbols
* [Serialization](#serialization) introduces how to save and load symbols.
* [Multiple Outputs](#multiple-outputs) introduces how to configure multiple outputs
* [Symbol Creation API Reference](#symbol-creation-api-reference) gives reference to all functions.
* [Symbol Object Document](#mxnet.symbol.Symbol) gives API reference to the Symbol Object.
* [Execution API Reference](#execution-api-reference) tells us what the executor can do.

How to Compose Symbols
----------------------
86 changes: 86 additions & 0 deletions example/cifar10/README.md
@@ -0,0 +1,86 @@

Machine: Dual Xeon E5-1650 3.5GHz, 4 GTX 980, CUDA 6.5

Run `cifar10.py`:


| | 1 GPU | 2 GPUs | 4 GPUs |
| --- | --- | --- | --- |
| cxxnet | 362 img/sec | 675 img/sec | 1282 img/sec |
| mxnet | 420 img/sec | 804 img/sec | 1436 img/sec |

Sample output:

```
~/mxnet/example/cifar10 $ python cifar10.py
INFO:root:Start training with 4 devices
Start training with 4 devices
INFO:root:Iteration[0] Train-accuracy=0.507613
Iteration[0] Train-accuracy=0.507613
INFO:root:Iteration[0] Time cost=34.800
Iteration[0] Time cost=34.800
INFO:root:Iteration[0] Validation-accuracy=0.641021
Iteration[0] Validation-accuracy=0.641021
INFO:root:Iteration[1] Train-accuracy=0.679408
Iteration[1] Train-accuracy=0.679408
INFO:root:Iteration[1] Time cost=34.481
Iteration[1] Time cost=34.481
INFO:root:Iteration[1] Validation-accuracy=0.720152
Iteration[1] Validation-accuracy=0.720152
INFO:root:Iteration[2] Train-accuracy=0.740825
Iteration[2] Train-accuracy=0.740825
INFO:root:Iteration[2] Time cost=34.463
Iteration[2] Time cost=34.463
INFO:root:Iteration[2] Validation-accuracy=0.755709
Iteration[2] Validation-accuracy=0.755709
```

Results from cxxnet for reference:

```
CXXNET Result:
step1: wmat_lr = 0.05, bias_lr = 0.1, mom = 0.9
[1] train-error:0.452865 val-error:0.3614
[2] train-error:0.280231 val-error:0.2504
[3] train-error:0.220968 val-error:0.2456
[4] train-error:0.18746 val-error:0.2145
[5] train-error:0.165221 val-error:0.1796
[6] train-error:0.150056 val-error:0.1786
[7] train-error:0.134571 val-error:0.157
[8] train-error:0.122582 val-error:0.1429
[9] train-error:0.113891 val-error:0.1398
[10] train-error:0.106458 val-error:0.1469
[11] train-error:0.0985054 val-error:0.1447
[12] train-error:0.0953684 val-error:0.1494
[13] train-error:0.0872962 val-error:0.1311
[14] train-error:0.0832401 val-error:0.1544
[15] train-error:0.0773857 val-error:0.1268
[16] train-error:0.0743087 val-error:0.125
[17] train-error:0.0714114 val-error:0.1189
[18] train-error:0.066616 val-error:0.1424
[19] train-error:0.0651175 val-error:0.1322
[20] train-error:0.0616808 val-error:0.111
step2: lr = 0.01, bias_lr = 0.02, mom = 0.9
[21] train-error:0.033368 val-error:0.0907
[22] train-error:0.0250959 val-error:0.0876
[23] train-error:0.0220388 val-error:0.0867
[24] train-error:0.0195812 val-error:0.0848
[25] train-error:0.0173833 val-error:0.0872
[26] train-error:0.0154052 val-error:0.0878
[27] train-error:0.0141264 val-error:0.0863
[28] train-error:0.0134071 val-error:0.0865
[29] train-error:0.0116688 val-error:0.0878
[30] train-error:0.0106298 val-error:0.0873
step3: lr = 0.001, bias_lr = 0.002, mom = 0.9
[31] train-error:-nan val-error:0.0873
[31] train-error:0.0067735 val-error:0.0859
[32] train-error:0.0049952 val-error:0.0835
[33] train-error:0.00485534 val-error:0.0849
[34] train-error:0.00367647 val-error:0.0839
[35] train-error:0.0034367 val-error:0.0844
[36] train-error:0.00275735 val-error:0.084
[37] train-error:0.00221787 val-error:0.083
[38] train-error:0.00171835 val-error:0.0838
[39] train-error:0.00125879 val-error:0.0833
[40] train-error:0.000699329 val-error:0.0842
```
80 changes: 13 additions & 67 deletions example/cifar10/cifar10.py
@@ -5,67 +5,13 @@
sys.path.insert(0, "../../python/")
sys.path.append("../../tests/python/common")
# import library
import logging
import mxnet as mx
import get_data
import time
import numpy as np
import copy


"""
CXXNET Result:
step1: wmat_lr = 0.05, bias_lr = 0.1, mom = 0.9
[1] train-error:0.452865 val-error:0.3614
[2] train-error:0.280231 val-error:0.2504
[3] train-error:0.220968 val-error:0.2456
[4] train-error:0.18746 val-error:0.2145
[5] train-error:0.165221 val-error:0.1796
[6] train-error:0.150056 val-error:0.1786
[7] train-error:0.134571 val-error:0.157
[8] train-error:0.122582 val-error:0.1429
[9] train-error:0.113891 val-error:0.1398
[10] train-error:0.106458 val-error:0.1469
[11] train-error:0.0985054 val-error:0.1447
[12] train-error:0.0953684 val-error:0.1494
[13] train-error:0.0872962 val-error:0.1311
[14] train-error:0.0832401 val-error:0.1544
[15] train-error:0.0773857 val-error:0.1268
[16] train-error:0.0743087 val-error:0.125
[17] train-error:0.0714114 val-error:0.1189
[18] train-error:0.066616 val-error:0.1424
[19] train-error:0.0651175 val-error:0.1322
[20] train-error:0.0616808 val-error:0.111
step2: lr = 0.01, bias_lr = 0.02, mom = 0.9
[21] train-error:0.033368 val-error:0.0907
[22] train-error:0.0250959 val-error:0.0876
[23] train-error:0.0220388 val-error:0.0867
[24] train-error:0.0195812 val-error:0.0848
[25] train-error:0.0173833 val-error:0.0872
[26] train-error:0.0154052 val-error:0.0878
[27] train-error:0.0141264 val-error:0.0863
[28] train-error:0.0134071 val-error:0.0865
[29] train-error:0.0116688 val-error:0.0878
[30] train-error:0.0106298 val-error:0.0873
step3: lr = 0.001, bias_lr = 0.002, mom = 0.9
[31] train-error:-nan val-error:0.0873
[31] train-error:0.0067735 val-error:0.0859
[32] train-error:0.0049952 val-error:0.0835
[33] train-error:0.00485534 val-error:0.0849
[34] train-error:0.00367647 val-error:0.0839
[35] train-error:0.0034367 val-error:0.0844
[36] train-error:0.00275735 val-error:0.084
[37] train-error:0.00221787 val-error:0.083
[38] train-error:0.00171835 val-error:0.0838
[39] train-error:0.00125879 val-error:0.0833
[40] train-error:0.000699329 val-error:0.0842
"""
def CalAcc(out, label):
pred = np.argmax(out, axis=1)
return np.sum(pred == label) * 1.0 / out.shape[0]


np.random.seed(1812)

conv_cnt = 1
concat_cnt = 1
pool_cnt = 1
@@ -133,13 +79,6 @@ def SimpleFactory(data, ch_1x1, ch_3x3):
concat_cnt += 1
return concat

def RandomInit(narray):
in_num = narray.shape[1]
out_num = narray.shape[0]
a = np.sqrt(3.0 / (in_num + out_num))
tmp = mx.nd.array(np.random.uniform(-a, a, narray.shape))
narray[:] = tmp

data = mx.symbol.Variable(name="data")
conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu")
in3a = SimpleFactory(conv1, 32, 32)
@@ -156,10 +95,14 @@ def RandomInit(narray):
flatten = mx.symbol.Flatten(data=pool, name="flatten1")
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=fc, name="loss")

#########################################################

get_data.GetCifar10()
batch_size = 128
epoch = 3
num_gpus = 1

train_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar_mean.bin",
@@ -178,11 +121,14 @@ def RandomInit(narray):
preprocess_threads=1)

def test_cifar():
model = mx.model.MXNetModel(ctx=mx.gpu(),
symbol=loss, data=(batch_size, 3, 28, 28),
optimizer="sgd", num_round = epoch, batch_size = batch_size,
learning_rate=0.05, momentum=0.9, weight_decay=0.00001)
model.fit(X=train_dataiter, eval_set=test_dataiter, eval_metric=CalAcc)
logging.basicConfig(level=logging.DEBUG)
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
logging.getLogger('').addHandler(console)
gpus = [mx.gpu(i) for i in range(num_gpus)]
model = mx.model.FeedForward(ctx=gpus, symbol=loss, num_round = epoch,
learning_rate=0.05, momentum=0.9, wd=0.00001)
model.fit(X=train_dataiter, eval_data=test_dataiter)


if __name__ == "__main__":