From 281ca21107770fd2726b4fe9a8d04a208c32cdd3 Mon Sep 17 00:00:00 2001 From: sneakerkg Date: Tue, 22 Sep 2015 01:59:40 +0800 Subject: [PATCH 01/20] add NumpyIter --- python/mxnet/io.py | 81 +++++++++++++++++++++++++++++++- tests/python/unittest/test_io.py | 13 ++++- 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 62e92bd020d5..537a86bd573b 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -5,11 +5,13 @@ import ctypes import sys +import numpy as np from .base import _LIB from .base import c_array, c_str, mx_uint, py_str from .base import DataIterHandle, NDArrayHandle from .base import check_call from .ndarray import NDArray +from .ndarray import array class DataIter(object): """DataIter object in mxnet. List all the needed functions here. """ @@ -83,6 +85,84 @@ def getlabel(self): check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) return NDArray(hdl, False) +class NumpyIter(DataIter): + """NumpyIter object in mxnet. Taking Numpy Array into dataiter. """ + + def __init__(self, *args, **kwargs): + """Initialize with handle + + Parameters + ---------- + handle : DataIterHandle + the handle to the underlying C++ Data Iterator + """ + self.data = args[0] + self.label = args[1] + self.batch_size = kwargs.get('batch_size', 100) + self.data_pad = kwargs.get('data_pad', 0) + self.label_pad = kwargs.get('label_pad', 0) + self.loc = 0 + self.out_data = None + self.out_label = None + + def __del__(self): + pass + + def reset(self): + """set loc to 0 + + """ + self.loc = 0 + + def iter_next(self): + """iterate to next data with return value + + Returns + ------- + return true if success + """ + if self.loc < self.data.shape[0]: + batch_data_shape = [] + batch_data_shape.append(self.batch_size) + for i in range(1,len(self.data.shape)): + batch_data_shape.append(self.data.shape[i]) + self.out_data = np.ones(batch_data_shape, dtype=self.data.dtype) * self.label_pad + self.out_label = np.ones([self.batch_size, 1], dtype=self.data.dtype) * self.label_pad + actual_size = min(self.data.shape[0] - self.loc, self.batch_size) + self.out_data[0:actual_size,::] = self.data[self.loc:self.loc+actual_size,::] + self.out_label[0:actual_size,::] = self.label[self.loc:self.loc+actual_size,::] + self.loc += actual_size + return True + else: + return False + + def next(self): + """get next data batch from iterator + + Returns + ------- + labels and images for the next batch + """ + if self.iter_next(): + return self.getdata(), self.getlabel() + else: + raise StopIteration + + # make it work for both python2 and 3 + __next__ = next + + def getdata(self): + """get data from batch + + """ + return array(self.out_data) + + def getlabel(self): + """get label from batch + + """ + return array(self.out_label) + def _make_io_iterator(handle): """Create an io iterator by handle.""" name = ctypes.c_char_p() @@ -156,7 +236,6 @@ def creator(*args, **kwargs): creator.__doc__ = doc_str return creator - def _init_io_module(): """List and add all the data iterators to current module.""" plist = ctypes.POINTER(ctypes.c_void_p)() diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 23d2afc18c03..71164cd12754 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -60,6 +60,17 @@ def test_Cifar10Rec(): for i in range(10): assert(labelcount[i] == 5000) +def test_NumpyIter(): + datas = np.ones([1000,100]) + labels = np.ones([1000, 1]) + for i in range(1000): + datas[i] = i / 100 + labels[i] = i / 100 + dataiter = mx.io.NumpyIter(datas, labels, batch_size=100) + for data, label in dataiter: + print data.asnumpy().flatten() + if __name__ == "__main__": - test_MNISTIter() + test_NumpyIter() + #test_MNISTIter() #test_Cifar10Rec() From 965d0583bf873e57cc09347bbb80fab1dfc70d91 Mon Sep 17 00:00:00 2001 From: sneakerkg Date: Tue, 22 Sep 2015 02:25:40 +0800 Subject: [PATCH 02/20] update old code --- python/mxnet/io.py | 15 ++++----------- tests/python/unittest/test_io.py | 15 ++++++++------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/python/mxnet/io.py b/python/mxnet/io.py index f95d9657814b..b0c6abfc1842 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -9,7 +9,7 @@ from .base import _LIB from .base import c_array, c_str, mx_uint, py_str from .base import DataIterHandle, NDArrayHandle -from .base import check_call +from .base import check_call, ctypes2docstring from .ndarray import NDArray from .ndarray import array @@ -180,24 +180,17 @@ def _make_io_iterator(handle): ctypes.byref(arg_types), \ ctypes.byref(arg_descs))) iter_name = py_str(name.value) - param_str = [] - for i in range(num_args.value): - ret = '%s : %s' % (arg_names[i], arg_types[i]) - if len(arg_descs[i]) != 0: - ret += '\n ' + py_str(arg_descs[i]) - param_str.append(ret) + param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs) doc_str = ('%s\n\n' + - 'Parameters\n' + - '----------\n' + '%s\n' + 'name : string, required.\n' + ' Name of the resulting data iterator.\n\n' + 'Returns\n' + '-------\n' + - 'iterator: Iterator\n'+ + 'iterator: DataIter\n'+ ' The result iterator.') - doc_str = doc_str % (desc.value, '\n'.join(param_str)) + doc_str = doc_str % (desc.value, param_str) def creator(*args, **kwargs): """Create an iterator. diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 42b154f8f43b..8cbcab112c07 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -15,7 +15,7 @@ def test_MNISTIter(): train_dataiter = mx.io.MNISTIter( image="data/train-images-idx3-ubyte", label="data/train-labels-idx1-ubyte", - input_shape=(784,), + data_shape=(784,), batch_size=batch_size, shuffle=1, flat=1, silent=0, seed=10) # test_loop nbatch = 60000 / batch_size @@ -44,10 +44,10 @@ def test_Cifar10Rec(): rand_crop=False, and_mirror=False, shuffle=False, - input_shape=(3,28,28), + data_shape=(3,28,28), batch_size=100, - nthread=4, - prefetch_capacity=1) + preprocess_threads=4, + prefetch_buffer=1) labelcount = [0 for i in range(10)] batchcount = 0 for data, label in dataiter: @@ -61,7 +61,7 @@ def test_Cifar10Rec(): assert(labelcount[i] == 5000) def test_NumpyIter(): - datas = np.ones([1000,100]) + datas = np.ones([1000, 2, 2]) labels = np.ones([1000, 1]) for i in range(1000): datas[i] = i / 100 @@ -69,6 +69,7 @@ def test_NumpyIter(): dataiter = mx.io.NumpyIter(datas, labels, batch_size=100) batchidx = 0 for data, label in dataiter: + print data.asnumpy() assert(label.asnumpy().flatten().sum() == batchidx * 100) batchidx += 1 dataiter.reset() @@ -82,6 +83,6 @@ def test_NumpyIter(): batchidx += 1 if __name__ == "__main__": - test_NumpyIter() + #test_NumpyIter() #test_MNISTIter() - #test_Cifar10Rec() + test_Cifar10Rec() From 0e3f93d9d1a1c4ff5d5915c8d7ceb1c0ae47354d Mon Sep 17 00:00:00 2001 From: sneakerkg Date: Tue, 22 Sep 2015 23:09:34 +0800 Subject: [PATCH 03/20] refactor NumpyIter --- python/mxnet/io.py | 156 +++++++++++++++++++------------ tests/python/unittest/test_io.py | 24 ++--- 2 files changed, 109 insertions(+), 71 deletions(-) diff --git a/python/mxnet/io.py b/python/mxnet/io.py index b0c6abfc1842..494bc793dd18 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -1,4 +1,5 @@ # coding: utf-8 +# pylint: disable=invalid-name, protected-access, fixme, too-many-arguments """NDArray interface of mxnet""" from __future__ import absolute_import @@ -14,20 +15,19 @@ from .ndarray import array class DataIter(object): - """DataIter object in mxnet. List all the needed functions here. """ + """DataIter object in mxnet. """ - def __init__(self, handle): - """Initialize with handle + def __init__(self): + """constructor of dataiter - Parameters - ---------- - handle : DataIterHandle - the handle to the underlying C++ Data Iterator """ - self.handle = handle + pass def __del__(self): - check_call(_LIB.MXDataIterFree(self.handle)) + """destructor of dataiter + + """ + pass def __iter__(self): """make the class iterable @@ -36,10 +36,10 @@ def __iter__(self): return self def reset(self): - """set loc to 0 + """reset the iter """ - check_call(_LIB.MXDataIterBeforeFirst(self.handle)) + pass def next(self): """get next data batch from iterator @@ -48,12 +48,7 @@ def next(self): ------- labels and images for the next batch """ - next_res = ctypes.c_int(0) - check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) - if next_res.value: - return self.getdata(), self.getlabel() - else: - raise StopIteration + pass # make it work for both python2 and 3 __next__ = next @@ -65,50 +60,69 @@ def iter_next(self): ------- return true if success """ - next_res = ctypes.c_int(0) - check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) - return next_res.value + pass def getdata(self): """get data from batch + Returns + ------- + data ndarray for the next batch """ - hdl = NDArrayHandle() - check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl))) - return NDArray(hdl, False) + pass def getlabel(self): """get label from batch + Returns + ------- + label ndarray for the next batch """ - hdl = NDArrayHandle() - check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) - return NDArray(hdl, False) + pass class NumpyIter(DataIter): - """NumpyIter object in mxnet. Taking Numpy Array into dataiter. """ + """NumpyIter object in mxnet. Taking Numpy Array to get dataiter. """ - def __init__(self, *args, **kwargs): + def __init__(self, data, label, batch_size, shuffle=True, data_pad=0, label_pad=0): """Initialize with handle Parameters ---------- - handle : DataIterHandle - the handle to the underlying C++ Data Iterator + data : numpy.array + Numpy ndarray for data + label : numpy.array + Numpy ndarray for label + batch_size: int + Batch Size + shuffle: bool + Whether to shuffle the data + data_pad: float + padding value for data + label_pad: float + padding value for label """ - super(NumpyIter, self).__init__(None) - self.data = args[0] - self.label = args[1] - self.batch_size = kwargs.get('batch_size', 100) - self.data_pad = kwargs.get('data_pad', 0) - self.label_pad = kwargs.get('label_pad', 0) + super(NumpyIter, self).__init__() + self.data = data + self.label = label + self.batch_size = batch_size + self.shuffle = shuffle + self.data_pad = data_pad + self.label_pad = label_pad + # shuffle data + if self.shuffle: + idx = np.arange(self.data.shape[0]) + np.random.shuffle(idx) + new_data = np.zeros(self.data.shape) + new_label = np.zeros(self.label.shape) + for i in range(self.data.shape[0]): + new_data[i] = self.data[idx[i]] + new_label[i] = self.label[idx[i]] + self.data = new_data + self.label = new_label self.loc = 0 self.out_data = None self.out_label = None - def __del__(self): - pass - def reset(self): """set loc to 0 @@ -116,12 +130,6 @@ def reset(self): self.loc = 0 def iter_next(self): - """iterate to next data with return value - - Returns - ------- - return true if success - """ if self.loc < self.data.shape[0]: batch_data_shape = [] batch_data_shape.append(self.batch_size) @@ -138,31 +146,59 @@ def iter_next(self): return False def next(self): - """get next data batch from iterator - - Returns - ------- - labels and images for the next batch - """ if self.iter_next(): return self.getdata(), self.getlabel() else: raise StopIteration - # make it work for both python2 and 3 - __next__ = next - def getdata(self): - """get data from batch - - """ return array(self.out_data) def getlabel(self): - """get label from batch + return array(self.out_label) + +class MXDataIter(DataIter): + """DataIter object in mxnet. List all the needed functions here. """ + + def __init__(self, handle): + """Initialize with handle + Parameters + ---------- + handle : DataIterHandle + the handle to the underlying C++ Data Iterator """ - return array(self.out_label) + super(MXDataIter, self).__init__() + self.handle = handle + + def __del__(self): + check_call(_LIB.MXDataIterFree(self.handle)) + + def reset(self): + check_call(_LIB.MXDataIterBeforeFirst(self.handle)) + + def next(self): + next_res = ctypes.c_int(0) + check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) + if next_res.value: + return self.getdata(), self.getlabel() + else: + raise StopIteration + + def iter_next(self): + next_res = ctypes.c_int(0) + check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) + return next_res.value + + def getdata(self): + hdl = NDArrayHandle() + check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl))) + return NDArray(hdl, False) + + def getlabel(self): + hdl = NDArrayHandle() + check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) + return NDArray(hdl, False) def _make_io_iterator(handle): """Create an io iterator by handle.""" @@ -224,7 +260,7 @@ def creator(*args, **kwargs): if len(args): raise TypeError('%s can only accept keyword arguments' % iter_name) - return DataIter(iter_handle) + return MXDataIter(iter_handle) creator.__name__ = iter_name creator.__doc__ = doc_str diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 8cbcab112c07..f526ec09cbaa 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -66,23 +66,25 @@ def test_NumpyIter(): for i in range(1000): datas[i] = i / 100 labels[i] = i / 100 - dataiter = mx.io.NumpyIter(datas, labels, batch_size=100) + dataiter = mx.io.NumpyIter(datas, labels, 128, True) batchidx = 0 for data, label in dataiter: - print data.asnumpy() - assert(label.asnumpy().flatten().sum() == batchidx * 100) batchidx += 1 + assert(batchidx == 8) dataiter.reset() batchidx = 0 - for i in range(1000): - datas[i] = i / 100 - labels[i] = i / 100 - dataiter = mx.io.NumpyIter(datas, labels, batch_size=100) + labelcount = [0 for i in range(10)] for data, label in dataiter: - assert(label.asnumpy().flatten().sum() == batchidx * 100) - batchidx += 1 + label = label.asnumpy().flatten() + for i in range(label.shape[0]): + labelcount[int(label[i])] += 1 + for i in range(10): + if i == 0: + assert(labelcount[i] == 124) + else: + assert(labelcount[i] == 100) if __name__ == "__main__": - #test_NumpyIter() + test_NumpyIter() #test_MNISTIter() - test_Cifar10Rec() + #test_Cifar10Rec() From 60fd2d4a36c89404c9de0d10144e6e98065effd4 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 11:13:39 -0700 Subject: [PATCH 04/20] Make model training multiple device --- python/mxnet/model.py | 218 +++++++++++++++++++++++++++++------------- 1 file changed, 152 insertions(+), 66 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 2c4614b1dade..1fd73eb4033e 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -9,6 +9,7 @@ from . import symbol as sym from . import optimizer as opt from . import metric +from . import kvstore from .context import Context, cpu from .initializer import Xavier @@ -74,12 +75,54 @@ def _check_arguments(symbol): return (data_index, label_index) -def _train(symbol, ctx, input_shape, - arg_params, aux_params, - begin_round, end_round, optimizer, - train_data, eval_data=None, eval_metric=None, - iter_end_callback=None, logger=None): - """Inernal training function. +def _split_input_slice(input_shape, num_split): + """Get input slice from the input shape. + + Parameters + ---------- + input_shape : tuple + The input shape of the net. + + num_split : int + The number of split we want to have. + + Returns + ------- + slices : list of slice + The split slices to get a specific slice. + + shapes : list of tuples + The shape of each split slice. + + Raises + ------ + ValueError + If there are two many splits such that some slice can be empty. + """ + batch_size = input_shape[0] + step = (batch_size + num_split - 1) / num_split + slices = [] + shapes = [] + for k in range(num_split): + begin = min(k * step, batch_size) + end = min((k+1) * step, batch_size) + if begin == end: + raise ValueError('Too many slices such that some splits are empty') + slices.append(slice(begin, end)) + s = list(input_shape) + s[0] = end - begin + shapes.append(tuple(s)) + return (slices, shapes) + + +def _train_multi_device(symbol, ctx, input_shape, + arg_params, aux_params, + begin_round, end_round, optimizer, + train_data, eval_data=None, eval_metric=None, + iter_end_callback=None, logger=None): + """Internal training function on multiple devices. + + This function will also work for single device as well. Parameters ---------- @@ -127,80 +170,121 @@ def _train(symbol, ctx, input_shape, ----- This function will inplace update the NDArrays in arg_parans and aux_states. """ - assert(len(ctx) == 1) if logger is None: logger = logging - # bind the symbol - train_exec = symbol.simple_bind(ctx[0], data=input_shape, grad_req='write') + # preparation + num_device = len(ctx) + logging.info('Start training with %d devices', num_device) + + slices, shapes = _split_input_slice(input_shape, num_device) + train_execs = [symbol.simple_bind(ctx=c, data=s, grad_req='write') + for c, s in zip(ctx, shapes)] arg_names = symbol.list_arguments() aux_names = symbol.list_auxiliary_states() - arg_arrays = train_exec.arg_arrays - grad_arrays = train_exec.grad_arrays - aux_arrays = train_exec.aux_arrays - # copy initialized parameters to executor parameters - for key, weight in zip(arg_names, arg_arrays): - if key in arg_params: - arg_params[key].copyto(weight) - for key, weight in zip(aux_names, aux_arrays): - if key in aux_params: - aux_params[key].copyto(weight) - # setup helper data structures + # data structure + arg_blocks = [ + [x.arg_arrays[index] for x in train_execs] + for index in range(len(train_execs[0].arg_arrays))] + grad_blocks = [ + [x.grad_arrays[index] for x in train_execs] + for index in range(len(train_execs[0].grad_arrays))] + aux_blocks = [ + [x.aux_arrays[index] for x in train_execs] + for index in range(len(train_execs[0].aux_arrays))] + for name, block in zip(arg_names, arg_blocks): + if name in arg_params: + for w in block: + arg_params[name].copyto(w) + for name, block in zip(aux_names, aux_blocks): + if name in aux_params: + for w in block: + aux_params[name].copyto(w) + # ky value store + kv = kvstore.create() if num_device != 1 else None + # If there are multiple devices, initialize the weights. + for index, pair in enumerate(zip(arg_blocks, grad_blocks)): + arg, grad = pair + if kv and grad[0] is not None: + kv.init(index, arg[0]) + # Input and output data structure data_index, label_index = _check_arguments(symbol) - data_array, label_array = arg_arrays[data_index], arg_arrays[label_index] - out_array = train_exec.outputs[0] - out_cpu_array = nd.zeros(out_array.shape) - arg_blocks = list(zip(arg_arrays, grad_arrays)) - - for i in range(begin_round, end_round): - # training phase + merged_shape = list(train_execs[0].outputs[0].shape) + merged_shape[0] = input_shape[0] + merged_shape = tuple(merged_shape) + out_cpu_array = nd.zeros(merged_shape, cpu()) + + # Now start training + for iteration in range(begin_round, end_round): + # Training phase tic = time.time() train_data.reset() - optimizer.begin_round(i) + optimizer.begin_round(iteration) eval_metric.reset() - + # Iterate over training data. for data, label in train_data: - label.copyto(label_array) - data.copyto(data_array) - train_exec.forward() - out_array.copyto(out_cpu_array) - train_exec.backward() + # Copy data into the target + for target, islice in zip(arg_blocks[label_index], slices): + label[islice].copyto(target) + for target, islice in zip(arg_blocks[data_index], slices): + data[islice].copyto(target) + # forward backward pass + for texec, islice in zip(train_execs, slices): + texec.forward() + texec.outputs[0].copyto(out_cpu_array[islice]) + for texec in train_execs: + texec.backward() # update the parameters - for index, block in enumerate(arg_blocks): - weight, grad = block - if grad is not None: - optimizer.update(index, weight, grad) + for index, pair in enumerate(zip(arg_blocks, grad_blocks)): + arg_list, grad_list = pair + if grad_list[0] is None: + continue + # Gradient synchronization + if kv: + # push gradient + kv.push(index, grad_list) + # pull back the sum, to the same locations. + kv.pull(index, grad_list) + # optimize + for w, g in zip(arg_list, grad_list): + optimizer.update(index, w, g) # evaluate at end, so out_cpu_array can lazy copy eval_metric.update(out_cpu_array, label) name, value = eval_metric.get() - logger.info('Iteration[%d] Train-%s=%f', i, name, value) + logger.info('Iteration[%d] Train-%s=%f', iteration, name, value) toc = time.time() - logger.info('Iteration[%d] Time cost=%.3f', i, (toc - tic)) - - # evaluation phase - if eval_data is not None: + logger.info('Iteration[%d] Time cost=%.3f', iteration, (toc - tic)) + # evaluation + if eval_data: eval_metric.reset() eval_data.reset() for data, label in eval_data: - data.copyto(data_array) - # TODO(bing): add is_train=False - train_exec.forward(is_train=False) - out_array.copyto(out_cpu_array) - eval_metric.update(out_array, label) - + # Copy data into the target + for target, islice in zip(arg_blocks[label_index], slices): + label[islice].copyto(target) + for target, islice in zip(arg_blocks[data_index], slices): + data[islice].copyto(target) + # forward pass + for texec, islice in zip(train_execs, slices): + texec.forward(is_train=False) + texec.outputs[0].copyto(out_cpu_array[islice]) + eval_metric.update(out_cpu_array, label) name, value = eval_metric.get() - logger.info('Iteration[%d] Validation-%s=%f', i, name, value) + logger.info('Iteration[%d] Validation-%s=%f', iteration, name, value) - if iter_end_callback or i + 1 == end_round: + if iter_end_callback or iteration + 1 == end_round: # copy data back to cpu - for key, weight in zip(arg_names, arg_arrays): - if key in arg_params: - weight.copyto(arg_params[key]) - for key, arr in zip(aux_names, aux_arrays): - arr.copyto(aux_params[key]) + for name, block in zip(arg_names, arg_blocks): + if name in arg_params: + weight = sum(w.copyto(cpu()) for w in block) / len(block) + weight.copyto(arg_params[name]) + for name, block in zip(aux_names, aux_blocks): + if name in aux_params: + weight = sum(w.copyto(cpu()) for w in block) / len(block) + weight.copyto(aux_params[name]) if iter_end_callback: - iter_end_callback(i, symbol, arg_params, aux_params) - # end of the function + iter_end_callback(iteration, symbol, arg_params, aux_params) + # end of all iterations return @@ -332,6 +416,8 @@ def __init__(self, symbol, ctx=None, num_round=None, optimizer='sgd', initializer=Xavier(), arg_params=None, aux_params=None, **kwargs): + # check if symbol contain duplicated names. + _check_arguments(symbol) # basic configuration self.symbol = symbol if ctx is None: @@ -467,14 +553,14 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', batch_size = input_shape[0] optimizer = opt.create(optimizer, rescale_grad=(1.0/batch_size), **(self.kwargs)) # do training - _train(self.symbol, self.ctx, input_shape, - self.arg_params, self.aux_params, - begin_round=0, end_round=self.num_round, - optimizer=optimizer, - train_data=X, eval_data=eval_data, - eval_metric=eval_metric, - iter_end_callback=iter_end_callback, - logger=logger) + _train_multi_device(self.symbol, self.ctx, input_shape, + self.arg_params, self.aux_params, + begin_round=0, end_round=self.num_round, + optimizer=optimizer, + train_data=X, eval_data=eval_data, + eval_metric=eval_metric, + iter_end_callback=iter_end_callback, + logger=logger) def save(self, prefix, iteration=None): """Checkpoint the model checkpoint into file. From 21981a8106f6a4758d3c788417dea4021a5ea557 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 12:14:24 -0700 Subject: [PATCH 05/20] Allow partial positional arguments of input symbol --- src/symbol/symbol.cc | 29 ++++++++++++++++++----------- tests/python/train/test_mlp.py | 12 ++++++------ 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 2b923cebc0d9..fb2377dbb6b2 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -246,6 +246,16 @@ Symbol Symbol::operator[] (size_t index) const { } } +// create a default variable name +inline std::string DefaultVarName(const std::string &op_name, + const std::string &arg_name) { + if (op_name.length() == 0) { + return arg_name; + } else { + return op_name + '_' + arg_name; + } +} + void Symbol::Compose(const std::vector& args, const std::string& name) { CHECK_EQ(NumOutputs(), 1) << "Only composition of value function is supported currently"; @@ -261,13 +271,17 @@ void Symbol::Compose(const std::vector& args, if (this->is_atomic()) { // atomic symbol do not have place holder for all the arguments std::vector req_args = heads_[0].source->op->ListArguments(); - CHECK_EQ(args.size(), req_args.size()) + CHECK_LE(args.size(), req_args.size()) << "Incorrect number of arguments, requires " << req_args.size() << ", provided " << args.size(); - heads_[0].source->inputs.resize(args.size()); + heads_[0].source->inputs.resize(req_args.size()); for (size_t i = 0; i < args.size(); ++i) { heads_[0].source->inputs[i] = args[i].heads_[0]; } + for (size_t i = args.size(); i < req_args.size(); ++i) { + heads_[0].source->inputs[i] = DataEntry( + std::make_shared(nullptr, DefaultVarName(name, req_args[i])), 0); + } } else { // find all the place holders size_t arg_counter = 0; @@ -325,15 +339,8 @@ void Symbol::Compose(const std::unordered_map& kwargs, heads_[0].source->inputs[i] = iter->second.heads_[0]; ++nmatched; } else { - // create a variable node - // TODO(bing): think of naming convention - if (name.length() == 0) { - heads_[0].source->inputs[i] = DataEntry( - std::make_shared(nullptr, req_args[i]), 0); - } else { - heads_[0].source->inputs[i] = DataEntry( - std::make_shared(nullptr, name + '_' + req_args[i]), 0); - } + heads_[0].source->inputs[i] = DataEntry( + std::make_shared(nullptr, DefaultVarName(name, req_args[i])), 0); } } // if things goes wrong recover the old state diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index 350adfde274e..40304187a5fa 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -9,12 +9,12 @@ # symbol net batch_size = 100 data = mx.symbol.Variable('data') -fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128) -act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") -fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64) -act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") -fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10) -softmax = mx.symbol.Softmax(data = fc3, name = 'sm') +fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128) +act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") +fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64) +act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu") +fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10) +softmax = mx.symbol.Softmax(fc3, name = 'sm') num_round = 4 prefix = './mlp' From f85a59bf8865767f31bf34a788fc16b0c33a2adb Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Mon, 21 Sep 2015 16:31:16 -0600 Subject: [PATCH 06/20] [BUG FIX] Slice in float type --- example/cifar10/cifar10.py | 18 +++++++++--------- example/{imagenet => notebooks}/alexnet.ipynb | 0 .../composite_symbol.ipynb | 0 python/mxnet/model.py | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) rename example/{imagenet => notebooks}/alexnet.ipynb (100%) rename example/{python-howto => notebooks}/composite_symbol.ipynb (100%) diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index b99c49ea7423..bf32308846e1 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -5,6 +5,7 @@ sys.path.insert(0, "../../python/") sys.path.append("../../tests/python/common") # import library +import logging import mxnet as mx import get_data import time @@ -59,10 +60,6 @@ [39] train-error:0.00125879 val-error:0.0833 [40] train-error:0.000699329 val-error:0.0842 """ -def CalAcc(out, label): - pred = np.argmax(out, axis=1) - return np.sum(pred == label) * 1.0 / out.shape[0] - np.random.seed(1812) @@ -178,11 +175,14 @@ def RandomInit(narray): preprocess_threads=1) def test_cifar(): - model = mx.model.MXNetModel(ctx=mx.gpu(), - symbol=loss, data=(batch_size, 3, 28, 28), - optimizer="sgd", num_round = epoch, batch_size = batch_size, - learning_rate=0.05, momentum=0.9, weight_decay=0.00001) - model.fit(X=train_dataiter, eval_set=test_dataiter, eval_metric=CalAcc) + logging.basicConfig(level=logging.DEBUG) + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + logging.getLogger('').addHandler(console) + # get model from symbol + model = mx.model.FeedForward(ctx=mx.gpu(), symbol=loss, num_round = epoch, + learning_rate=0.05, momentum=0.9, wd=0.00001) + model.fit(X=train_dataiter, eval_data=test_dataiter) if __name__ == "__main__": diff --git a/example/imagenet/alexnet.ipynb b/example/notebooks/alexnet.ipynb similarity index 100% rename from example/imagenet/alexnet.ipynb rename to example/notebooks/alexnet.ipynb diff --git a/example/python-howto/composite_symbol.ipynb b/example/notebooks/composite_symbol.ipynb similarity index 100% rename from example/python-howto/composite_symbol.ipynb rename to example/notebooks/composite_symbol.ipynb diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 1fd73eb4033e..11e440988f51 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -104,8 +104,8 @@ def _split_input_slice(input_shape, num_split): slices = [] shapes = [] for k in range(num_split): - begin = min(k * step, batch_size) - end = min((k+1) * step, batch_size) + begin = int(min(k * step, batch_size)) + end = int(min((k+1) * step, batch_size)) if begin == end: raise ValueError('Too many slices such that some splits are empty') slices.append(slice(begin, end)) From 3911d92b39928e61e60574d86fcda41f2a51003a Mon Sep 17 00:00:00 2001 From: Mu Li Date: Mon, 21 Sep 2015 20:24:30 -0400 Subject: [PATCH 07/20] use mult-gpus in cifar10.py --- example/cifar10/cifar10.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index bf32308846e1..f00b5bc36650 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -180,7 +180,8 @@ def test_cifar(): console.setLevel(logging.DEBUG) logging.getLogger('').addHandler(console) # get model from symbol - model = mx.model.FeedForward(ctx=mx.gpu(), symbol=loss, num_round = epoch, + gpus = [mx.gpu(i) for i in range(2)] + model = mx.model.FeedForward(ctx=gpus, symbol=loss, num_round = epoch, learning_rate=0.05, momentum=0.9, wd=0.00001) model.fit(X=train_dataiter, eval_data=test_dataiter) From 6a434ce3ccf4bbef46cbde291c78cb4859070cfb Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Mon, 21 Sep 2015 17:09:24 -0600 Subject: [PATCH 08/20] move mnist to ualberta server --- example/notebooks/alexnet.ipynb | 19 +++++++++++-------- tests/python/common/get_data.py | 22 +++++++++------------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/example/notebooks/alexnet.ipynb b/example/notebooks/alexnet.ipynb index b7bb6bf266c2..c030d873cd08 100644 --- a/example/notebooks/alexnet.ipynb +++ b/example/notebooks/alexnet.ipynb @@ -29,7 +29,6 @@ }, "outputs": [], "source": [ - "%matplotlib inline\n", "import mxnet as mx" ] }, @@ -402,7 +401,7 @@ } ], "source": [ - "mx.visualization.plot_network(\"AlexNet\", softmax)" + "mx.viz.plot_network(\"AlexNet\", softmax)" ] }, { @@ -425,28 +424,32 @@ "# We set batch size for to 256\n", "batch_size = 256\n", "# We need to set correct path to image record file\n", - "# For ```mean_image```. if it doesn't exist, the iterator will generate one. Usually on normal HDD, it costs less than 10 minutes\n", + "# For ```mean_image```. if it doesn't exist, the iterator will generate one\n", + "# On HDD, single thread is able to process 800 images / sec\n", "# the input shape is in format (channel, height, width)\n", "# rand_crop option make source image randomly cropped to input_shape (3, 224, 224)\n", "# rand_mirror option make source image randomly mirrored\n", "# We use 2 threads to processing our data\n", "train_dataiter = mx.io.ImageRecordIter(\n", + " shuffle=True,\n", " path_imgrec=\"./Data/ImageNet/train.rec\",\n", " mean_img=\"./Data/ImageNet/mean_224.bin\",\n", " rand_crop=True,\n", " rand_mirror=True,\n", - " input_shape=(3, 224, 224),\n", + " data_shape=(3, 224, 224),\n", " batch_size=batch_size,\n", - " nthread=2)\n", + " prefetch_buffer=4,\n", + " preprocess_threads=2)\n", "# similarly, we can declare our validation iterator\n", "val_dataiter = mx.io.ImageRecordIter(\n", " path_imgrec=\"./Data/ImageNet/val.rec\",\n", " mean_img=\"./Data/ImageNet/mean_224.bin\",\n", " rand_crop=False,\n", " rand_mirror=False,\n", - " input_shape=(3, 224, 224),\n", + " data_shape=(3, 224, 224),\n", " batch_size=batch_size,\n", - " nthread=2)" + " prefetch_buffer=4,\n", + " preprocess_threads=2)" ] }, { @@ -531,7 +534,7 @@ "# When we use data iterator, we don't need to set y because label comes from data iterator directly\n", "# In this case, eval_data is also a data iterator\n", "# We will use accuracy to measure our model's performace\n", - "model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc', verbose=True)\n", + "model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc')\n", "# You need to wait for a while to get the result" ] }, diff --git a/tests/python/common/get_data.py b/tests/python/common/get_data.py index 270132e448b8..65e8ac59ad6f 100644 --- a/tests/python/common/get_data.py +++ b/tests/python/common/get_data.py @@ -14,18 +14,14 @@ def GetMNIST_pkl(): def GetMNIST_ubyte(): if not os.path.isdir("data/"): os.system("mkdir data/") - if not os.path.exists('data/train-images-idx3-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -P data/") - os.system("gunzip data/train-images-idx3-ubyte.gz") - if not os.path.exists('data/train-labels-idx1-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -P data/") - os.system("gunzip data/train-labels-idx1-ubyte.gz") - if not os.path.exists('data/t10k-images-idx3-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz -P data/") - os.system("gunzip data/t10k-images-idx3-ubyte.gz") - if not os.path.exists('data/t10k-labels-idx1-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz -P data/") - os.system("gunzip data/t10k-labels-idx1-ubyte.gz") + if (not os.path.exists('data/train-images-idx3-ubyte')) or \ + (not os.path.exists('data/train-labels-idx1-ubyte')) or \ + (not os.path.exists('data/t10k-images-idx3-ubyte')) or \ + (not os.path.exists('data/t10k-labels-idx1-ubyte')): + os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip -P data/") + os.chdir("./data") + os.system("unzip -u mnist.zip") + os.chdir("..") # download cifar def GetCifar10(): @@ -34,5 +30,5 @@ def GetCifar10(): if not os.path.exists('data/cifar10.zip'): os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/cifar10.zip -P data/") os.chdir("./data") - os.system("unzip cifar10.zip") + os.system("unzip -u cifar10.zip") os.chdir("..") From 1aaf30a62e47b46f8657ffaee67b0b9829dc9a6c Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 16:31:51 -0700 Subject: [PATCH 09/20] add name manager --- python/mxnet/__init__.py | 1 + python/mxnet/context.py | 37 ++++++++++++------- python/mxnet/name.py | 77 ++++++++++++++++++++++++++++++++++++++++ python/mxnet/symbol.py | 4 +++ 4 files changed, 107 insertions(+), 12 deletions(-) create mode 100644 python/mxnet/name.py diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 89dd2c09da79..e9630b678ee0 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -12,6 +12,7 @@ from .base import MXNetError from . import base from . import ndarray +from . import name from . import symbol from . import kvstore as kv from . import io diff --git a/python/mxnet/context.py b/python/mxnet/context.py index fff45dc7b895..485d292aa203 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -3,23 +3,36 @@ from __future__ import absolute_import class Context(object): - """Context representing device and device id in mxnet""" + """Constructing a context. + + Parameters + ---------- + device_type : {'cpu', 'gpu'} or Context. + String representing the device type + + device_id : int (default=0) + The device id of the device, needed for GPU + + Note + ---- + Context can also be used a way to change default context. + + Examples + -------- + Switch default context example: + >>> # array on cpu + >>> cpu_array = mx.md.ones((2, 3)) + >>> # switch default context to GPU(2) + >>> with mx.Context(mx.gpu(2)): + >>> gpu_array = mx.md.ones((2, 3)) + >>> gpu_array.context + Context(device_type=gpu, device_id=2) + """ # static class variable default_ctx = None devtype2str = {1: 'cpu', 2: 'gpu'} devstr2type = {'cpu': 1, 'gpu': 2} - def __init__(self, device_type, device_id=0): - """Constructing a context. - - Parameters - ---------- - device_type : str (can be 'cpu' or 'gpu') - a string representing the device type - - device_id : int (default=0) - the device id of the device, needed for GPU - """ if isinstance(device_type, Context): self.device_typeid = device_type.device_typeid self.device_id = device_type.device_id diff --git a/python/mxnet/name.py b/python/mxnet/name.py new file mode 100644 index 000000000000..f4e5109d20e1 --- /dev/null +++ b/python/mxnet/name.py @@ -0,0 +1,77 @@ +# coding: utf-8 +"""Automatic naming support for symbolic API.""" + +class NameManager(object): + """NameManager to do automatic naming. + + User can also inheritate this object to change naming behavior. + """ + current = None + + def __init__(self): + self._counter = {} + self._old_manager = None + + def get(self, name, hint): + """Get the canonical name for a symbol. + + This is default implementation. + When user specified a name, + the user specified name will be used. + + When user did not, we will automatically generate a + name based on hint string. + + Parameters + ---------- + name : str or None + The name user specified. + + hint : str + A hint string, which can be used to generate name. + + Returns + ------- + full_name : str + A canonical name for the user. + """ + if name: + return name + if hint not in self._counter: + self._counter[hint] = 0 + name = '%s%d' % (hint, self._counter[hint]) + self._counter[hint] += 1 + return name + + def __enter__(self): + self._old_manager = NameManager.current + NameManager.current = self + return self + + def __exit__(self, ptype, value, trace): + assert self._old_manager + NameManager.current = self._old_manager + + +class Prefix(NameManager): + """A name manager that always attach a prefix to all names. + + Examples + -------- + >>> import mxnet as mx + >>> data = mx.symbol.Variable('data') + >>> with mx.name.Prefix('mynet_'): + net = mx.symbol.FullyConnected(data, num_hidden=10, name='fc1') + >>> net.list_arguments() + ['data', 'mynet_fc1_weight', 'mynet_fc1_bias'] + """ + def __init__(self, prefix): + super(Prefix, self).__init__() + self._prefix = prefix + + def get(self, name, hint): + name = super(Prefix, self).get(name, hint) + return self._prefix + name + +# initialize the default name manager +NameManager.current = NameManager() diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index a7d53f28fc1a..44318c66200c 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -13,6 +13,7 @@ from .base import c_array, c_str, mx_uint, py_str, string_types from .base import NDArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call, ctypes2docstring +from .name import NameManager from .context import Context from .ndarray import NDArray, zeros from .executor import Executor @@ -128,6 +129,7 @@ def _compose(self, *args, **kwargs): the resulting symbol """ name = kwargs.pop('name', None) + if name: name = c_str(name) if len(args) != 0 and len(kwargs) != 0: @@ -752,6 +754,8 @@ def creator(*args, **kwargs): ' instead of keyword arguments.') s = Symbol(sym_handle) + hint = func_name.lower() + name = NameManager.current.get(name, hint) s._compose(*args, name=name, **symbol_kwargs) return s From 70814edde067639fb7aefa1c9ac4a4afc412ab02 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 17:12:00 -0700 Subject: [PATCH 10/20] fix optimizer under multi device --- python/mxnet/initializer.py | 31 ++++++++++++++++------------- python/mxnet/kvstore.py | 1 + python/mxnet/model.py | 22 +++++++++++++++------ python/mxnet/name.py | 1 + python/mxnet/optimizer.py | 36 ++++++++++++++++++++++++---------- python/mxnet/visualization.py | 2 ++ tests/python/train/test_mlp.py | 2 +- 7 files changed, 64 insertions(+), 31 deletions(-) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index bd64413ca295..fa13594926f7 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -1,4 +1,7 @@ -# pylint: skip-file +# coding: utf-8 +"""Initialization helper for mxnet""" +from __future__ import absolute_import + import numpy as np from .base import string_types from .ndarray import NDArray @@ -36,17 +39,17 @@ def __call__(self, name, arr): self._init_zero(name, arr) else: self._init_default(name, arr) - - def _init_zero(self, name, arr): + # pylint: disable=no-self-use, missing-docstring + def _init_zero(self, _, arr): arr[:] = 0.0 - def _init_bias(self, name, arr): + def _init_bias(self, _, arr): arr[:] = 0.0 - def _init_gamma(self, name, arr): + def _init_gamma(self, _, arr): arr[:] = 1.0 - def _init_beta(self, name, arr): + def _init_beta(self, _, arr): arr[:] = 0.0 def _init_weight(self, name, arr): @@ -55,7 +58,7 @@ def _init_weight(self, name, arr): def _init_default(self, name, _): raise ValueError('Unknown initialization pattern for %s' % name) - + # pylint: enable=no-self-use, missing-docstring class Uniform(Initializer): """Initialize the weight with uniform [-scale, scale] @@ -68,8 +71,8 @@ class Uniform(Initializer): def __init__(self, scale=0.07): self.scale = scale - def _init_weight(self, name, arr): - random.uniform(-scale, scale, out=arr) + def _init_weight(self, _, arr): + random.uniform(-self.scale, self.scale, out=arr) class Normal(Initializer): @@ -81,10 +84,10 @@ class Normal(Initializer): Standard deviation for gaussian distribution. """ def __init__(self, sigma=0.01): - super().__init__(sigma = sigma) + self.sigma = sigma - def _init_weight(self, name, arr): - random.normal(0, sigma, out=arr) + def _init_weight(self, _, arr): + random.normal(0, self.sigma, out=arr) class Xavier(Initializer): @@ -95,6 +98,6 @@ def _init_weight(self, _, arr): # [in, out] for fullc shape = arr.shape fan_in, fan_out = shape[1], shape[0] - s = np.sqrt(6. / (fan_in + fan_out)) - random.uniform(-s, s, out=arr) + scale = np.sqrt(6. / (fan_in + fan_out)) + random.uniform(-scale, scale, out=arr) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 8f4822d85a69..ac5691d37909 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -2,6 +2,7 @@ # pylint: disable=invalid-name, global-statement """ KVStore in mxnet """ from __future__ import absolute_import + import ctypes from .ndarray import NDArray from .base import _LIB diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 11e440988f51..3466bf9362d9 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -1,6 +1,8 @@ # pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals # pylint: disable=too-many-branches, too-many-statements, unused-argument """MXNet model module""" +from __future__ import absolute_import + import numpy as np import time import logging @@ -201,11 +203,18 @@ def _train_multi_device(symbol, ctx, input_shape, aux_params[name].copyto(w) # ky value store kv = kvstore.create() if num_device != 1 else None + opt_state_blocks = [] # If there are multiple devices, initialize the weights. for index, pair in enumerate(zip(arg_blocks, grad_blocks)): - arg, grad = pair - if kv and grad[0] is not None: - kv.init(index, arg[0]) + arg_list, grad_list = pair + if kv and grad_list[0] is not None: + kv.init(index, arg_list[0]) + # attach state direct to weight + opt_list = [optimizer.create_state(index, w) for w in arg_list] + opt_state_blocks.append(opt_list) + else: + opt_state_blocks.append(None) + # Input and output data structure data_index, label_index = _check_arguments(symbol) merged_shape = list(train_execs[0].outputs[0].shape) @@ -244,9 +253,10 @@ def _train_multi_device(symbol, ctx, input_shape, kv.push(index, grad_list) # pull back the sum, to the same locations. kv.pull(index, grad_list) - # optimize - for w, g in zip(arg_list, grad_list): - optimizer.update(index, w, g) + opt_list = opt_state_blocks[index] + # optimizea + for w, g, state in zip(arg_list, grad_list, opt_list): + optimizer.update(index, w, g, state) # evaluate at end, so out_cpu_array can lazy copy eval_metric.update(out_cpu_array, label) diff --git a/python/mxnet/name.py b/python/mxnet/name.py index f4e5109d20e1..b0c8bff52a8a 100644 --- a/python/mxnet/name.py +++ b/python/mxnet/name.py @@ -1,5 +1,6 @@ # coding: utf-8 """Automatic naming support for symbolic API.""" +from __future__ import absolute_import class NameManager(object): """NameManager to do automatic naming. diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 8cc3d1b4f241..d1f0ae4ef246 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -1,4 +1,4 @@ -# pylint: disable=fixme, invalid-name +# pylint: disable=fixme, invalid-name, unused-argument """Common Optimization algorithms with regularizations.""" from .ndarray import NDArray, zeros @@ -31,7 +31,6 @@ class SGD(Optimizer): rescale_grad : float, optional rescaling factor of gradient. - """ def __init__(self, learning_rate=0.01, momentum=0.0, wd=0.0001, rescale_grad=1): @@ -41,7 +40,21 @@ def __init__(self, learning_rate=0.01, momentum=0.0, self.rescale_grad = rescale_grad self.momentums = {} - def update(self, index, weight, grad): + def create_state(self, index, weight): + """Create additional optimizer state such as momentum. + + Parameters + ---------- + weight : NDArray + The weight data + + """ + if self.momentum == 0.0: + return None + else: + return zeros(weight.shape, weight.context) + + def update(self, index, weight, grad, state): """Update the parameters. Parameters @@ -55,17 +68,20 @@ def update(self, index, weight, grad): grad : NDArray grad ndarray + state : NDArray or other objects returned by init_state + The auxiliary state used in optimization. """ # TODO(bing) implement wd_bias, wd_gamma, wd_beta assert(isinstance(weight, NDArray)) assert(isinstance(grad, NDArray)) - - if index not in self.momentums: - self.momentums[index] = zeros(grad.shape, grad.context) - mom = self.momentums[index] - mom[:] *= self.momentum - mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight) - weight[:] += mom + if state: + mom = state + mom[:] *= self.momentum + mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight) + weight[:] += mom + else: + assert self.momentum == 0.0 + weight[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight) def create(name, rescale_grad=1, **kwargs): diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 86fc53c37311..3992a241b69f 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -2,6 +2,8 @@ # pylint: disable=invalid-name, protected-access, too-many-locals, fixme # pylint: disable=unused-argument, too-many-branches, too-many-statements """Visualization module""" +from __future__ import absolute_import + from .symbol import Symbol import json import re diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index 40304187a5fa..b0849a3e81d9 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -18,7 +18,7 @@ num_round = 4 prefix = './mlp' -model = mx.model.FeedForward(softmax, mx.cpu(), +model = mx.model.FeedForward(softmax, [mx.cpu()] * 2, num_round=num_round, learning_rate=0.01, wd=0.0004, momentum=0.9) From 1e6e1366c9b62b9094b80c2b35cad47f0a22c2ea Mon Sep 17 00:00:00 2001 From: Mu Li Date: Mon, 21 Sep 2015 20:31:30 -0400 Subject: [PATCH 11/20] multi-gpus in cifar10 --- example/cifar10/cifar10.py | 4 +- example/cifar10/cifar10_multi_gpus.py | 262 -------------------------- 2 files changed, 3 insertions(+), 263 deletions(-) delete mode 100644 example/cifar10/cifar10_multi_gpus.py diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index f00b5bc36650..db35927196bc 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -157,6 +157,8 @@ def RandomInit(narray): get_data.GetCifar10() batch_size = 128 epoch = 3 +num_gpus = 1 + train_dataiter = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar_mean.bin", @@ -180,7 +182,7 @@ def test_cifar(): console.setLevel(logging.DEBUG) logging.getLogger('').addHandler(console) # get model from symbol - gpus = [mx.gpu(i) for i in range(2)] + gpus = [mx.gpu(i) for i in range(num_gpus)] model = mx.model.FeedForward(ctx=gpus, symbol=loss, num_round = epoch, learning_rate=0.05, momentum=0.9, wd=0.00001) model.fit(X=train_dataiter, eval_data=test_dataiter) diff --git a/example/cifar10/cifar10_multi_gpus.py b/example/cifar10/cifar10_multi_gpus.py deleted file mode 100644 index e68e6edfc77b..000000000000 --- a/example/cifar10/cifar10_multi_gpus.py +++ /dev/null @@ -1,262 +0,0 @@ -# Pylint: skip-file -import numpy as np -import mxnet as mx -import copy -import sys -sys.path.append("../../tests/python/common") -import get_data -import time - -# use multiple devices -num_devs = 4 -devs = [mx.gpu(i) for i in range(num_devs)] -mx.kv.start() - -# define the network -conv_cnt = 1 -concat_cnt = 1 -pool_cnt = 1 - -def ConvFactory(**kwargs): - global conv_cnt - param = copy.copy(kwargs) - act = param["act_type"] - del param["act_type"] - param["workspace"] = 256 - param["name"] = "conv%d" % conv_cnt - conv = mx.symbol.Convolution(**param) - bn = mx.symbol.BatchNorm(data = conv, name="bn%d" % conv_cnt) - relu = mx.symbol.Activation(data = bn, name = "%s%d" % (act, conv_cnt), act_type=act) - conv_cnt += 1 - return relu - -def DownsampleFactory(data, ch_3x3, stride = 2): - global pool_cnt - global concat_cnt - param = {} - # conv 3x3 - param["kernel"] = (3, 3) - param["stride"] = (stride, stride) - param["num_filter"] = ch_3x3 - param["act_type"] = "relu" - param["data"] = data - param["pad"] = (1, 1) - conv3x3 = ConvFactory(**param) - # pool - del param["num_filter"] - del param["act_type"] - del param["pad"] - param["pool_type"] = "max" - param["name"] = "pool%d" % pool_cnt - pool = mx.symbol.Pooling(**param) - pool_cnt += 1 - # concat - concat = mx.symbol.Concat(*[conv3x3, pool], name="concat%d" % concat_cnt) - concat_cnt += 1 - return concat - -def SimpleFactory(data, ch_1x1, ch_3x3): - global concat_cnt - param = {} - # 1x1 - param["kernel"] = (1, 1) - param["num_filter"] = ch_1x1 - param["pad"] = (0, 0) - param["stride"] = (1, 1) - param["act_type"] = "relu" - param["data"] = data - conv1x1 = ConvFactory(**param) - - # 3x3 - param["kernel"] = (3, 3) - param["num_filter"] = ch_3x3 - param["pad"] = (1, 1) - conv3x3 = ConvFactory(**param) - - #concat - concat = mx.symbol.Concat(*[conv1x1, conv3x3], name="concat%d" % concat_cnt) - concat_cnt += 1 - return concat - -data = mx.symbol.Variable(name="data") -conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu") -in3a = SimpleFactory(conv1, 32, 32) -in3b = SimpleFactory(in3a, 32, 48) -in3c = DownsampleFactory(in3b, 80) -in4a = SimpleFactory(in3c, 112, 48) -in4b = SimpleFactory(in4a, 96, 64) -in4c = SimpleFactory(in4b, 80, 80) -in4d = SimpleFactory(in4c, 48, 96) -in4e = DownsampleFactory(in4d, 96) -in5a = SimpleFactory(in4e, 176, 160) -in5b = SimpleFactory(in5a, 176, 160) -pool = mx.symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="pool%d" % pool_cnt) -flatten = mx.symbol.Flatten(data=pool, name="flatten1") -fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1") -loss = mx.symbol.Softmax(data=fc, name="loss") - -# define model updater - -def momentum(learning_rate=.01, weight_decay=0.0001, momentum=0.9): - """Stochastic Gradient Descent (SGD) updates with momentum - """ - momentums = {} - def momentum_update(key, grad, weight): - # weight += - learning_rate * (grad + weight_decay * weight) - if not momentums.has_key(key): - momentums[key] = mx.nd.zeros(grad.shape) - mom = momentums[key] - mom *= momentum - mom += - learning_rate * (grad + weight_decay * weight) - weight += mom - return momentum_update - -updater = momentum( - learning_rate = .05, weight_decay = .0001, momentum = 0.9) -mx.kv.set_updater(updater) - -# infer shape -batch_size = 196 -batch_size -= (batch_size % num_devs) -data_shape = (batch_size / num_devs, 3, 28, 28) - -# create executors for devices -executors = [loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) for d in devs] - -# find the params needed to be synchronized between devices -param_names = loss.list_arguments() -sync_prefix = ["weight", "bias", "beta", "gamma"] -sync_indices = [index for index, name in enumerate(param_names) - if any(prefix in name for prefix in sync_prefix)] - -sync_weights = [[e.list_arguments()[0][i] for e in executors] for i in sync_indices] -sync_grads = [[e.list_arguments()[1][i] for e in executors] for i in sync_indices] - - -# init model -weights = executors[0].list_arguments()[0] -for idx in sync_indices: - shape = weights[idx].shape - val = mx.nd.zeros(shape) - if "weight" in param_names[idx]: - val[:] = np.random.uniform(-0.1, 0.1, shape) - elif "gamma" in param_names[idx]: - val[:] = 1.0 - mx.kv.init(idx, val) - -# data reader -get_data.GetCifar10() - -train_dataiter = mx.io.ImageRecordIter( - path_imgrec="data/cifar/train.rec", - mean_img="data/cifar/cifar_mean.bin", - rand_crop=True, - rand_mirror=True, - shuffle=False, - data_shape=(3,28,28), - batch_size=batch_size, - preprocess_threads=4, - prefetch_buffer=6) -test_dataiter = mx.io.ImageRecordIter( - path_imgrec="data/cifar/test.rec", - mean_img="data/cifar/cifar_mean.bin", - rand_crop=False, - rand_mirror=False, - shuffle=False, - data_shape=(3,28,28), - batch_size=batch_size, - preprocess_threads=4, - prefetch_buffer=6) - -def progress(count, total, epoch, tic): - bar_len = 50 - filled_len = int(round(bar_len * count / float(total))) - percents = round(100.0 * count / float(total), 1) - bar = '=' * filled_len + '-' * (bar_len - filled_len) - toc = time.time() - speed = batch_size / float(toc - tic) - suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed) - sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix)) - -def cal_acc(out, label): - pred = np.argmax(out, axis=1) - return np.sum(pred == label) * 1.0 / out.shape[0] - -def train(): - epoch = 7 - acc_train = 0. - acc_val = 0. - k = batch_size / num_devs - batch_splits = [range(d*k, (d+1)*k) for d in range(num_devs)] - print("Start training...") - data_in = [e.list_arguments()[0][param_names.index('data')] for e in executors] - label_in = [e.list_arguments()[0][param_names.index('loss_label')] for e in executors] - - for i in range(epoch): - # train - start = time.time() - train_acc = 0.0 - val_acc = 0.0 - train_count = 0 - val_count = 0 - all_train_bacth = round(50000 / float(batch_size/num_devs) + 1) - - for data, label in train_dataiter: - tic = time.time() - # pull weight - mx.kv.pull(sync_indices, out = sync_weights) - - # forward and backword - data = data.asnumpy() - label = label.asnumpy().flatten() - for d in range(num_devs): - rows = batch_splits[d] - data_in[d][:] = data[rows, :] - label_in[d][:] = label[rows] - executors[d].forward() - executors[d].backward() - - # normalize gradient - for grads in sync_grads: - for g in grads: - g /= batch_size - - # push gradient - mx.kv.push(sync_indices, sync_grads) - - # evaluate - for d in range(num_devs): - train_acc += cal_acc(executors[d].outputs[0].asnumpy(), - label[batch_splits[d]]) - train_count += 1 - - progress(train_count, all_train_bacth, i, tic) - - # evaluate - for data, label in val_dataiter: - # forward - data = data.asnumpy() - label = label.asnumpy().flatten() - for d in range(num_devs): - rows = batch_splits[d] - data_in[d][:] = data[rows,:] - executors[d].forward() - - # eval - for d in range(num_devs): - val_acc += cal_acc(executors[d].outputs[0].asnumpy(), - label[batch_splits[d]]) - val_count += 1 - - sys.stdout.write('\n') - - print("Train Acc: %g, Valid Acc: %g, Time: %g sec" % ( - train_acc / train_count, - val_acc / val_count, - time.time() - start)) - - train_dataiter.reset() - val_dataiter.reset() - -if __name__ == "__main__": - train() From ec2df592c807136609484b00f86de398aec7144a Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 16:31:51 -0700 Subject: [PATCH 12/20] BUGFIX --- python/mxnet/model.py | 5 +++-- tests/python/train/test_mlp.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 3466bf9362d9..a84244a2a777 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -207,8 +207,9 @@ def _train_multi_device(symbol, ctx, input_shape, # If there are multiple devices, initialize the weights. for index, pair in enumerate(zip(arg_blocks, grad_blocks)): arg_list, grad_list = pair - if kv and grad_list[0] is not None: - kv.init(index, arg_list[0]) + if grad_list[0] is not None: + if kv: + kv.init(index, arg_list[0]) # attach state direct to weight opt_list = [optimizer.create_state(index, w) for w in arg_list] opt_state_blocks.append(opt_list) diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index b0849a3e81d9..bd635e980297 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -18,7 +18,8 @@ num_round = 4 prefix = './mlp' -model = mx.model.FeedForward(softmax, [mx.cpu()] * 2, +model = mx.model.FeedForward(softmax, + [mx.cpu(i) for i in range(2)], num_round=num_round, learning_rate=0.01, wd=0.0004, momentum=0.9) From 5b13ccd683721403a0cacb8098206fc78435fadf Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 18:47:44 -0700 Subject: [PATCH 13/20] More on model doc --- doc/python/index.md | 2 + doc/python/model.md | 116 ++++++++++++++++++++ python/mxnet/model.py | 53 ++++++++- tests/python/test_mlp_multi_devices.py.bak | 120 --------------------- tests/python/train/test_mlp.py | 24 +++-- 5 files changed, 185 insertions(+), 130 deletions(-) create mode 100644 doc/python/model.md delete mode 100644 tests/python/test_mlp_multi_devices.py.bak diff --git a/doc/python/index.md b/doc/python/index.md index 3a1713c81afa..82e1eaaa9ed1 100644 --- a/doc/python/index.md +++ b/doc/python/index.md @@ -20,4 +20,6 @@ Python API Documents -------------------- * [NDArray API](ndarray.md) * [Symbolic API](symbol.md) +* [KVStore API](kvstore.md) * [Data Loading API](io.md) +* [Model API](model.md) \ No newline at end of file diff --git a/doc/python/model.md b/doc/python/model.md new file mode 100644 index 000000000000..bd15379eeeee --- /dev/null +++ b/doc/python/model.md @@ -0,0 +1,116 @@ +MXNet Python Model API +====================== +The model API in mxnet as not really an API. +It is a thin wrapper build on top of [ndarray](ndarray.md) and [symbolic](symbol.md) +modules to make neural network training easy. + +* [Train a Model](#overloaded-operators) introduces operator overloading of symbols +* [Serialization](#serialization) introduces how to save and load symbols. +* [Multiple Outputs](#multiple-outputs) introduces how to configure multiple outputs +* [API Reference](#api-reference) gives reference to all functions. +* [Symbol Object Document](#mxnet.symbol.Symbol) gives API reference to the Symbol Object. + + +Train a Model +------------- +To train a model, you can follow two steps, first a configuration using symbol, +then call ```model.Feedforward.create``` to create a model for you. +The following example creates a two layer neural networks. + +```python +batch_size = 100 +data = mx.symbol.Variable('data') +fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128) +act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") +fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64) +softmax = mx.symbol.Softmax(fc2, name = 'sm') + +model = mx.model.FeedForward.create( + softmax, + X=data_set, + num_round=num_round, + learning_rate=0.01) +``` + +You can also use scikit-learn style construct and fit function to create a model. +For more information, you can refer to [Model API Reference](#model-api-reference). + +Save the Model +-------------- +It is important to save your work after the job done. +To save the model, you can directly pickle it if you like the pythonic way. +We also provide a save and load function. + +```python +# save a model to mymodel-symbol.json and mymodel-0100.params +prefix = 'mymodel' +model.save(prefix, 100) + +# load model back +model_loaded = mx.model.FeedForward.load(prefix, 100) +``` +The advantage of this save and load function is they are language agnostic, +and you should be able to save and load directly into cloud storage such as S3 and HDFS. + +Periodically Checkpoint +----------------------- +It is also helpful to periodically checkpoint your model after each iteration. +To do so, you can simply add a checkpoint callback to the function. +The training process will automatically checkpoint to the specified place after +each iteration. + +```python +prefix='models/chkpt' +model = mx.model.FeedForward.create( + softmax, + X=data_set, + iter_end_callback=mx.model.do_checkpoint(prefix), + num_round=num_round, + learning_rate=0.01) +``` +You can load the model checkpoint later using ```Feedforward.load```. + +Use Multiple Devices +-------------------- +Simply set ```ctx``` to be the list of devices you like to train on. + +```python +devices = [mx.gpu(i) for i in range(num_device)] +model = mx.model.FeedForward.create( + softmax, + X=dataset, + ctx=devices, + ...) +``` + +Initializer API Reference +------------------------- + +```eval_rst +.. automodule:: mxnet.initializer + :members: +``` + +Evaluation Metric API Reference +------------------------------- + +```eval_rst +.. automodule:: mxnet.metric + :members: +``` + +Optimizer API Reference +----------------------- + +```eval_rst +.. automodule:: mxnet.optimizer + :members: +``` + +Model API Reference +------------------- + +```eval_rst +.. automodule:: mxnet.model + :members: +``` diff --git a/python/mxnet/model.py b/python/mxnet/model.py index a84244a2a777..df07512af64d 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -529,7 +529,7 @@ def predict(self, X): def fit(self, X, y=None, eval_data=None, eval_metric='acc', iter_end_callback=None, logger=None): - """fit the model + """Fit the model. Parameters ---------- @@ -629,3 +629,54 @@ def load(prefix, iteration, ctx=None): return FeedForward(symbol, ctx=ctx, arg_params=arg_params, aux_params=aux_params) + @staticmethod + def create(symbol, X, y=None, ctx=None, + num_round=None, optimizer='sgd', initializer=Xavier(), + eval_data=None, eval_metric='acc', iter_end_callback=None, + logger=None, **kwargs): + """Functional style to create a model. + + This function will be more consistent with functional + languages such as R, where mutation is not allowed. + + Parameters + ---------- + symbol : Symbol + The symbol configuration of computation network. + + X : DataIter + Training data + + y : numpy.ndarray, optional + If X is numpy.ndarray y is required to set + + ctx : Context or list of Context, optional + The device context of training and prediction. + To use multi GPU training, pass in a list of gpu contexts. + + num_round : int, optional + Training parameter, number of training rounds(iterations). + + optimizer : str or Optimizer, optional + Training parameter, name or optimizer object for training. + + initializier : initializer function, optional + Training parameter, the initialization scheme used. + + eval_data : DataIter or numpy.ndarray pair + If eval_set is numpy.ndarray pair, it should be (valid_data, valid_label) + + eval_metric : function + Evaluation metric function. + + iter_end_callback : callable(iteration, symbol, arg_params, aux_states) + A callback that is invoked at end of each iteration. + This can be used to checkpoint model each iteration. + + logger : logging logger, optional + """ + model = FeedForward(symbol, ctx=ctx, num_round=num_round, + optimizer=optimizer, initializer=initializer, **kwargs) + model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric, + iter_end_callback=iter_end_callback, logger=logger) + return model diff --git a/tests/python/test_mlp_multi_devices.py.bak b/tests/python/test_mlp_multi_devices.py.bak deleted file mode 100644 index 7a2e7ce1938a..000000000000 --- a/tests/python/test_mlp_multi_devices.py.bak +++ /dev/null @@ -1,120 +0,0 @@ -# pylint: skip-file -import sys -sys.path.append('../../python/') - -import mxnet as mx -import numpy as np -import os, gzip -import pickle as pickle -import get_data - -# symbol net -data = mx.symbol.Variable('data') -fc1 = mx.symbol.FullyConnected(data = data, name='fc1', nb_hidden=128) -act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") -fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', nb_hidden = 64) -act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") -fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', nb_hidden=10) -mlp = mx.symbol.Softmax(data = fc3, name = 'mlp') - -# use multiple devices -num_devs = 2 -devs = [mx.Context('cpu', i) for i in range(num_devs)] - -# infer shape -batch_size = 100 -input_shape = (batch_size / num_devs, 784) -param_shapes, out_shapes, aux_shapes = mlp.infer_shape(data=input_shape) -param_names = mlp.list_arguments() - -# allocate memory -params = [[mx.narray.create(s, d) for s in param_shapes] for d in devs]; -grads = [[mx.narray.create(s, d) for s in param_shapes] for d in devs]; - -# only need to init param on device 0 -mx.kvstore.init_devices(devs) -sync_keys = [i for i,m in enumerate(param_names) if "weight" in m or "bias" in m] -np.random.seed(0) -for k in sync_keys: - if "weight" in param_names[k]: - params[0][k].numpy[:, :] = np.random.uniform(-0.07, 0.07, v.numpy.shape) - else: - params[0][k].numpy[:] = 0 -mx.kvstore.init((k,params[0][k]) for k in sync_keys) - -# register param updater -def make_updater(env): - def updater(grad, weight): - eta = env['lr'] / sqrt(env['iter']) / env['batch_size'] - env['iter'] += 1 - weight[:] -= eta * grad - return updater - -mx.kvstore.register(make_updater( - {'lr' : 0.1, 'batch_size' : batch_size, 'wd' : .00004})) - -# create exector for each device - -req = ['write_to' for i in range(len(param_names))] -executors = [mlp.bind(devs[i], params[i], grads[i], req) for i in range(num_devs)] -forward_out = [mx.narray.create(e.heads()[0].shape) for e in executors] - -# data reader -get_data.GetMNIST_ubyte() -train_dataiter = mx.io.MNISTIter( - image="data/train-images-idx3-ubyte", - label="data/train-labels-idx1-ubyte", - batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10) -val_dataiter = mx.io.MNISTIter( - image="data/t10k-images-idx3-ubyte", - label="data/t10k-labels-idx1-ubyte", - batch_size=batch_size, shuffle=True, flat=True, silent=False) - -def cal_acc(out, label): - pred = np.argmax(out, axis=1) - return np.sum(pred == label) * 1.0 / out.shape[0] - -def test_mlp(): - epoch = 9 - acc_train = 0. - acc_val = 0. - for i in range(epoch): - # train - print("Epoch %d" % i) - train_acc = 0.0 - for data, label in train_dataiter: - data = data.numpy - label = label.numpy.flatten() - k = batch_size / num_devs - - for d in range(num_devs): - # feed input - idx = range(d*k, (d+1)*k) - params[d][param_names.index('data')].numpy[:] = data[idx,:] - params[d][param_names.index('mlp_label')].numpy[:] = label[idx] - - # pull weight - mx.kvstore.pull((k,params[d][k]) for k in sync_keys) - - # forward and backward - executors[d].forward() - executors[d].heads()[0].copyto(forward_out[d]) - executors[d].backward([forward_out[d]]) - - # push gradient - mx.kvstore.push((k, grads[d][k]) for k in sync_keys) - - # evaluate. cannot put into the above for loop since it is blocked - # until all forwards are finished - for d in range(num_devs): - train_acc += cal_acc(forward_out[d].numpy, label[range(d*k, (d+1)*k)]) - - train_acc /= train_nbatch - train_nbatch += 1 - print("Train Acc: ", train_acc) - train_dataiter.reset() - - assert(acc_train > 0.98) - -if __name__ == "__main__": - test_mlp() diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index bd635e980297..dad0ef0f1db5 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -18,11 +18,7 @@ num_round = 4 prefix = './mlp' -model = mx.model.FeedForward(softmax, - [mx.cpu(i) for i in range(2)], - num_round=num_round, - learning_rate=0.01, wd=0.0004, - momentum=0.9) + #check data get_data.GetMNIST_ubyte() @@ -44,10 +40,17 @@ def test_mlp(): console.setLevel(logging.DEBUG) logging.getLogger('').addHandler(console) - model.fit(X=train_dataiter, - eval_data=val_dataiter, - iter_end_callback=mx.model.do_checkpoint(prefix)) - logging.info('Finish fit...') + model = mx.model.FeedForward.create( + softmax, + X=train_dataiter, + eval_data=val_dataiter, + iter_end_callback=mx.model.do_checkpoint(prefix), + ctx=[mx.cpu(i) for i in range(2)], + num_round=num_round, + learning_rate=0.01, wd=0.0004, + momentum=0.9) + + logging.info('Finish traning...') prob = model.predict(val_dataiter) logging.info('Finish predict...') val_dataiter.reset() @@ -69,6 +72,9 @@ def test_mlp(): assert np.sum(np.abs(prob - prob3)) == 0 # save model explicitly + + + model.save(prefix, 128) model4 = mx.model.FeedForward.load(prefix, 128) prob4 = model4.predict(val_dataiter) From 7c8d4eb0d50d32cc5d6f8362195479659d7d2687 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 18:52:38 -0700 Subject: [PATCH 14/20] [KVStore] BUGFIX GPU merge --- src/kvstore/kvstore_local.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index b08ea9997369..f8cbf53d27b3 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -119,7 +119,7 @@ class KVStoreLocal : public KVStore { } else { CHECK_EQ(ctx.dev_mask(), gpu::kDevMask); NDArray *copy_buf = buf.AllocCopyBuf(ctx.dev_id, val[0].shape()); - CopyFromTo(val[0], copy_buf); + CopyFromTo(val[i], copy_buf); buf.merged += *copy_buf; } } From 91abecdb8e770dfe201cd287c08b55c76fe1bef1 Mon Sep 17 00:00:00 2001 From: Mu Li Date: Mon, 21 Sep 2015 22:43:41 -0400 Subject: [PATCH 15/20] add results for multi-gpu cifar --- example/cifar10/README.md | 86 ++++++++++++++++++++++++++++++++++++++ example/cifar10/cifar10.py | 63 ++-------------------------- 2 files changed, 89 insertions(+), 60 deletions(-) create mode 100644 example/cifar10/README.md diff --git a/example/cifar10/README.md b/example/cifar10/README.md new file mode 100644 index 000000000000..fdc2d1916e9f --- /dev/null +++ b/example/cifar10/README.md @@ -0,0 +1,86 @@ + +Machine: Dual Xeon E5-1650 3.5GHz, 4 GTX 980, Cuda 6.5 + +run `cifar.py`: + + +| | 1 GPU | 2 GPUs | 4 GPUs | +| --- | --- | --- | --- | +| cxxnet | 362 img/sec | 675 img/sec | 1282 img/sec | +| mxnet | 420 img/sec | 804 img/sec | 1436 img/sec | + +sample output + +``` +~/mxnet/example/cifar10 $ python cifar10.py +INFO:root:Start training with 4 devices +Start training with 4 devices +INFO:root:Iteration[0] Train-accuracy=0.507613 +Iteration[0] Train-accuracy=0.507613 +INFO:root:Iteration[0] Time cost=34.800 +Iteration[0] Time cost=34.800 +INFO:root:Iteration[0] Validation-accuracy=0.641021 +Iteration[0] Validation-accuracy=0.641021 +INFO:root:Iteration[1] Train-accuracy=0.679408 +Iteration[1] Train-accuracy=0.679408 +INFO:root:Iteration[1] Time cost=34.481 +Iteration[1] Time cost=34.481 +INFO:root:Iteration[1] Validation-accuracy=0.720152 +Iteration[1] Validation-accuracy=0.720152 +INFO:root:Iteration[2] Train-accuracy=0.740825 +Iteration[2] Train-accuracy=0.740825 +INFO:root:Iteration[2] Time cost=34.463 +Iteration[2] Time cost=34.463 +INFO:root:Iteration[2] Validation-accuracy=0.755709 +Iteration[2] Validation-accuracy=0.755709 +``` + +results from cxxnet for reference + +``` +CXXNET Result: +step1: wmat_lr = 0.05, bias_lr = 0.1, mom = 0.9 +[1] train-error:0.452865 val-error:0.3614 +[2] train-error:0.280231 val-error:0.2504 +[3] train-error:0.220968 val-error:0.2456 +[4] train-error:0.18746 val-error:0.2145 +[5] train-error:0.165221 val-error:0.1796 +[6] train-error:0.150056 val-error:0.1786 +[7] train-error:0.134571 val-error:0.157 +[8] train-error:0.122582 val-error:0.1429 +[9] train-error:0.113891 val-error:0.1398 +[10] train-error:0.106458 val-error:0.1469 +[11] train-error:0.0985054 val-error:0.1447 +[12] train-error:0.0953684 val-error:0.1494 +[13] train-error:0.0872962 val-error:0.1311 +[14] train-error:0.0832401 val-error:0.1544 +[15] train-error:0.0773857 val-error:0.1268 +[16] train-error:0.0743087 val-error:0.125 +[17] train-error:0.0714114 val-error:0.1189 +[18] train-error:0.066616 val-error:0.1424 +[19] train-error:0.0651175 val-error:0.1322 +[20] train-error:0.0616808 val-error:0.111 +step2: lr = 0.01, bias_lr = 0.02, mom = 0.9 +[21] train-error:0.033368 val-error:0.0907 +[22] train-error:0.0250959 val-error:0.0876 +[23] train-error:0.0220388 val-error:0.0867 +[24] train-error:0.0195812 val-error:0.0848 +[25] train-error:0.0173833 val-error:0.0872 +[26] train-error:0.0154052 val-error:0.0878 +[27] train-error:0.0141264 val-error:0.0863 +[28] train-error:0.0134071 val-error:0.0865 +[29] train-error:0.0116688 val-error:0.0878 +[30] train-error:0.0106298 val-error:0.0873 +step3: lr = 0.001, bias_lr = 0.002, mom = 0.9 +[31] train-error:-nan val-error:0.0873 +[31] train-error:0.0067735 val-error:0.0859 +[32] train-error:0.0049952 val-error:0.0835 +[33] train-error:0.00485534 val-error:0.0849 +[34] train-error:0.00367647 val-error:0.0839 +[35] train-error:0.0034367 val-error:0.0844 +[36] train-error:0.00275735 val-error:0.084 +[37] train-error:0.00221787 val-error:0.083 +[38] train-error:0.00171835 val-error:0.0838 +[39] train-error:0.00125879 val-error:0.0833 +[40] train-error:0.000699329 val-error:0.0842 +``` diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index db35927196bc..9ce01c3ee2c0 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -12,57 +12,6 @@ import numpy as np import copy - -""" -CXXNET Result: -step1: wmat_lr = 0.05, bias_lr = 0.1, mom = 0.9 -[1] train-error:0.452865 val-error:0.3614 -[2] train-error:0.280231 val-error:0.2504 -[3] train-error:0.220968 val-error:0.2456 -[4] train-error:0.18746 val-error:0.2145 -[5] train-error:0.165221 val-error:0.1796 -[6] train-error:0.150056 val-error:0.1786 -[7] train-error:0.134571 val-error:0.157 -[8] train-error:0.122582 val-error:0.1429 -[9] train-error:0.113891 val-error:0.1398 -[10] train-error:0.106458 val-error:0.1469 -[11] train-error:0.0985054 val-error:0.1447 -[12] train-error:0.0953684 val-error:0.1494 -[13] train-error:0.0872962 val-error:0.1311 -[14] train-error:0.0832401 val-error:0.1544 -[15] train-error:0.0773857 val-error:0.1268 -[16] train-error:0.0743087 val-error:0.125 -[17] train-error:0.0714114 val-error:0.1189 -[18] train-error:0.066616 val-error:0.1424 -[19] train-error:0.0651175 val-error:0.1322 -[20] train-error:0.0616808 val-error:0.111 -step2: lr = 0.01, bias_lr = 0.02, mom = 0.9 -[21] train-error:0.033368 val-error:0.0907 -[22] train-error:0.0250959 val-error:0.0876 -[23] train-error:0.0220388 val-error:0.0867 -[24] train-error:0.0195812 val-error:0.0848 -[25] train-error:0.0173833 val-error:0.0872 -[26] train-error:0.0154052 val-error:0.0878 -[27] train-error:0.0141264 val-error:0.0863 -[28] train-error:0.0134071 val-error:0.0865 -[29] train-error:0.0116688 val-error:0.0878 -[30] train-error:0.0106298 val-error:0.0873 -step3: lr = 0.001, bias_lr = 0.002, mom = 0.9 -[31] train-error:-nan val-error:0.0873 -[31] train-error:0.0067735 val-error:0.0859 -[32] train-error:0.0049952 val-error:0.0835 -[33] train-error:0.00485534 val-error:0.0849 -[34] train-error:0.00367647 val-error:0.0839 -[35] train-error:0.0034367 val-error:0.0844 -[36] train-error:0.00275735 val-error:0.084 -[37] train-error:0.00221787 val-error:0.083 -[38] train-error:0.00171835 val-error:0.0838 -[39] train-error:0.00125879 val-error:0.0833 -[40] train-error:0.000699329 val-error:0.0842 -""" - -np.random.seed(1812) - conv_cnt = 1 concat_cnt = 1 pool_cnt = 1 @@ -130,13 +79,6 @@ def SimpleFactory(data, ch_1x1, ch_3x3): concat_cnt += 1 return concat -def RandomInit(narray): - in_num = narray.shape[1] - out_num = narray.shape[0] - a = np.sqrt(3.0 / (in_num + out_num)) - tmp = mx.nd.array(np.random.uniform(-a, a, narray.shape)) - narray[:] = tmp - data = mx.symbol.Variable(name="data") conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu") in3a = SimpleFactory(conv1, 32, 32) @@ -153,11 +95,13 @@ def RandomInit(narray): flatten = mx.symbol.Flatten(data=pool, name="flatten1") fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1") loss = mx.symbol.Softmax(data=fc, name="loss") + ######################################################### + get_data.GetCifar10() batch_size = 128 epoch = 3 -num_gpus = 1 +num_gpus = 4 train_dataiter = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", @@ -181,7 +125,6 @@ def test_cifar(): console = logging.StreamHandler() console.setLevel(logging.DEBUG) logging.getLogger('').addHandler(console) - # get model from symbol gpus = [mx.gpu(i) for i in range(num_gpus)] model = mx.model.FeedForward(ctx=gpus, symbol=loss, num_round = epoch, learning_rate=0.05, momentum=0.9, wd=0.00001) From 6207fa4f37a22059df7a3165060b8010d14ae5a5 Mon Sep 17 00:00:00 2001 From: Mu Li Date: Mon, 21 Sep 2015 22:45:11 -0400 Subject: [PATCH 16/20] use 1 gpu in default --- example/cifar10/cifar10.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index 9ce01c3ee2c0..7944985caa4c 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -101,7 +101,7 @@ def SimpleFactory(data, ch_1x1, ch_3x3): get_data.GetCifar10() batch_size = 128 epoch = 3 -num_gpus = 4 +num_gpus = 1 train_dataiter = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", From 57477d3829226ff032da3ef90c8968735da43d6d Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 19:26:45 -0700 Subject: [PATCH 17/20] enable callable to metric --- doc/python/index.md | 4 ++-- doc/python/model.md | 33 +++++++++++++++++---------------- doc/python/symbol.md | 5 +++-- python/mxnet/metric.py | 28 ++++++++++++++++++++++++---- python/mxnet/model.py | 14 +++++++++----- tests/python/train/test_mlp.py | 7 +++++++ 6 files changed, 62 insertions(+), 29 deletions(-) diff --git a/doc/python/index.md b/doc/python/index.md index 82e1eaaa9ed1..dcbcfd57e70d 100644 --- a/doc/python/index.md +++ b/doc/python/index.md @@ -18,8 +18,8 @@ Code Examples Python API Documents -------------------- +* [High Level Model Training Related API](model.md) * [NDArray API](ndarray.md) * [Symbolic API](symbol.md) * [KVStore API](kvstore.md) -* [Data Loading API](io.md) -* [Model API](model.md) \ No newline at end of file +* [Data Loading API](io.md) \ No newline at end of file diff --git a/doc/python/model.md b/doc/python/model.md index bd15379eeeee..686d698dec0e 100644 --- a/doc/python/model.md +++ b/doc/python/model.md @@ -4,12 +4,12 @@ The model API in mxnet as not really an API. It is a thin wrapper build on top of [ndarray](ndarray.md) and [symbolic](symbol.md) modules to make neural network training easy. -* [Train a Model](#overloaded-operators) introduces operator overloading of symbols -* [Serialization](#serialization) introduces how to save and load symbols. -* [Multiple Outputs](#multiple-outputs) introduces how to configure multiple outputs -* [API Reference](#api-reference) gives reference to all functions. -* [Symbol Object Document](#mxnet.symbol.Symbol) gives API reference to the Symbol Object. - +* [Train a Model](#train-a-model) introduces basic training. +* [Save the Model](#save-the-model) +* [Periodically Checkpoint](#periodically-checkpoint) +* [Initializer API Reference](#initializer-api-reference) +* [Evaluation Metric API Reference](#initializer-api-reference) +* [Optimizer API Reference](#optimizer-api-reference) Train a Model ------------- @@ -18,13 +18,13 @@ then call ```model.Feedforward.create``` to create a model for you. The following example creates a two layer neural networks. ```python -batch_size = 100 +# configure a two layer neuralnetwork data = mx.symbol.Variable('data') fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128) -act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") -fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64) -softmax = mx.symbol.Softmax(fc2, name = 'sm') - +act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu') +fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64) +softmax = mx.symbol.Softmax(fc2, name='sm') +# create a model model = mx.model.FeedForward.create( softmax, X=data_set, @@ -44,10 +44,11 @@ We also provide a save and load function. ```python # save a model to mymodel-symbol.json and mymodel-0100.params prefix = 'mymodel' -model.save(prefix, 100) +iteration = 100 +model.save(prefix, iteration) # load model back -model_loaded = mx.model.FeedForward.load(prefix, 100) +model_loaded = mx.model.FeedForward.load(prefix, iteration) ``` The advantage of this save and load function is they are language agnostic, and you should be able to save and load directly into cloud storage such as S3 and HDFS. @@ -55,7 +56,7 @@ and you should be able to save and load directly into cloud storage such as S3 a Periodically Checkpoint ----------------------- It is also helpful to periodically checkpoint your model after each iteration. -To do so, you can simply add a checkpoint callback to the function. +To do so, you can simply add a checkpoint callback ```do_checkpoint(path)``` to the function. The training process will automatically checkpoint to the specified place after each iteration. @@ -65,8 +66,7 @@ model = mx.model.FeedForward.create( softmax, X=data_set, iter_end_callback=mx.model.do_checkpoint(prefix), - num_round=num_round, - learning_rate=0.01) + ...) ``` You can load the model checkpoint later using ```Feedforward.load```. @@ -82,6 +82,7 @@ model = mx.model.FeedForward.create( ctx=devices, ...) ``` +The training will be done in a data parallel way on the GPUs you specified. Initializer API Reference ------------------------- diff --git a/doc/python/symbol.md b/doc/python/symbol.md index 3b08f72fdd03..51bf46cfb160 100644 --- a/doc/python/symbol.md +++ b/doc/python/symbol.md @@ -3,8 +3,9 @@ MXNet Python Symbolic API * [How to Commpose Symbols](#overloaded-operators) introduces operator overloading of symbols * [Serialization](#serialization) introduces how to save and load symbols. * [Multiple Outputs](#multiple-outputs) introduces how to configure multiple outputs -* [API Reference](#api-reference) gives reference to all functions. -* [Symbol Object Document](#mxnet.symbol.Symbol) gives API reference to the Symbol Object. +* [Symbol Creation API Reference](#symbol-creationapi-reference) gives reference to all functions. +* [Symbol Object Document](#mxnet.symbol.Symbol) gives API reference to the Symbol Object +* [Execution API Reference](#execution-api-reference) tell us on what executor can do. How to Compose Symbols ---------------------- diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index ad0aa55d332a..1fca626db502 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -1,5 +1,6 @@ # pylint: disable=invalid-name """Online evaluation metric module.""" +from .base import string_types import numpy as np class EvalMetric(object): @@ -52,15 +53,34 @@ def update(self, pred, label): self.num_inst += label.size -def create(name): +class CustomMetric(EvalMetric): + """Calculate accuracy""" + def __init__(self, feval): + name = feval.__name__ + if name.find('<') != -1: + name = 'custom(%s)' % name + super(CustomMetric, self).__init__(name) + self._feval = feval + + def update(self, pred, label): + self.sum_metric += self._feval(pred, label) + self.num_inst += 1 + + +def create(metric): """Create an evaluation metric. Parameters ---------- - name : str - The name of the metric + metric : str or callable + The name of the metric, or a function + providing statistics given pred, label NDArray. """ - if name == 'acc' or name == 'accuracy': + if callable(metric): + return CustomMetric(metric) + if not isinstance(metric, string_types): + raise TypeError('metric should either be callable or str') + if metric == 'acc' or metric == 'accuracy': return Accuracy() else: raise ValueError('Cannot find metric %s' % name) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index df07512af64d..e19aa580a0a5 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -542,8 +542,10 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', eval_data : DataIter or numpy.ndarray pair If eval_set is numpy.ndarray pair, it should be (valid_data, valid_label) - eval_metric : function - Evaluation metric function. + eval_metric : metric.EvalMetric or str or callable + The evaluation metric, name of evaluation metric. + Or a customize evaluation function that returns the statistics + based on minibatch. iter_end_callback : callable(iteration, symbol, arg_params, aux_states) A callback that is invoked at end of each iteration. @@ -556,7 +558,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', if self.arg_params is None: self._init_params(input_shape) # setup metric - if isinstance(eval_metric, str): + if not isinstance(eval_metric, metric.EvalMetric): eval_metric = metric.create(eval_metric) # setup optimizer optimizer = self.optimizer @@ -666,8 +668,10 @@ def create(symbol, X, y=None, ctx=None, eval_data : DataIter or numpy.ndarray pair If eval_set is numpy.ndarray pair, it should be (valid_data, valid_label) - eval_metric : function - Evaluation metric function. + eval_metric : metric.EvalMetric or str or callable + The evaluation metric, name of evaluation metric. + Or a customize evaluation function that returns the statistics + based on minibatch. iter_end_callback : callable(iteration, symbol, arg_params, aux_states) A callback that is invoked at end of each iteration. diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index dad0ef0f1db5..5ad44fe0350b 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -16,6 +16,12 @@ fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10) softmax = mx.symbol.Softmax(fc3, name = 'sm') +def accuracy(pred, label): + pred = pred.asnumpy() + label = label.asnumpy().astype('int32') + py = np.argmax(pred, axis=1) + return np.sum(py == label) / float(label.size) + num_round = 4 prefix = './mlp' @@ -44,6 +50,7 @@ def test_mlp(): softmax, X=train_dataiter, eval_data=val_dataiter, + eval_metric=accuracy, iter_end_callback=mx.model.do_checkpoint(prefix), ctx=[mx.cpu(i) for i in range(2)], num_round=num_round, From 32060d275c32fa252c0739e8b0c6cebfc2e13a1c Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 21:04:43 -0700 Subject: [PATCH 18/20] Update error handling --- mshadow | 2 +- python/mxnet/metric.py | 2 +- src/engine/naive_engine.cc | 20 +++----------------- src/engine/stream_manager.h | 9 +-------- src/engine/threaded_engine.h | 2 +- src/engine/threaded_engine_perdevice.cc | 18 ++---------------- src/resource.cc | 18 ++---------------- 7 files changed, 11 insertions(+), 60 deletions(-) diff --git a/mshadow b/mshadow index c6f53473ee4b..bf678e6ac05d 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit c6f53473ee4bfd834bf38cd3ff630e395ff662b4 +Subproject commit bf678e6ac05d5115f92db0b668e4424401f31b14 diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 1fca626db502..6438198ce44d 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -83,4 +83,4 @@ def create(metric): if metric == 'acc' or metric == 'accuracy': return Accuracy() else: - raise ValueError('Cannot find metric %s' % name) + raise ValueError('Cannot find metric %s' % metric) diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index 2ad0e6772a0d..bef0722919e0 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -20,14 +20,7 @@ class NaiveEngine final : public Engine { for (size_t i = 0; i < streams_.size(); ++i) { if (streams_[i] != nullptr) { // Catch exception for CUDA driver shutdown - try { - mshadow::DeleteStream(streams_[i]); - } catch (const dmlc::Error &e) { - std::string what = e.what(); - if (what.find("driver shutting down") == std::string::npos) { - LOG(ERROR) << "Ignore Error " << what << " during worker finalization"; - } - } + MSHADOW_CATCH_ERROR(mshadow::DeleteStream(streams_[i])); streams_[i] = nullptr; } } @@ -63,14 +56,7 @@ class NaiveEngine final : public Engine { if (exec_ctx.dev_mask() == gpu::kDevMask) { #if MXNET_USE_CUDA size_t dev_id = static_cast(exec_ctx.dev_id); - try { - mshadow::SetDevice(exec_ctx.dev_id); - } catch (const dmlc::Error &e) { - std::string what = e.what(); - if (what.find("driver shutting down") == std::string::npos) { - LOG(ERROR) << "Ignore Error " << what << " during worker finalization"; - } - } + MSHADOW_CATCH_ERROR(mshadow::SetDevice(exec_ctx.dev_id)); if (streams_.size() <= dev_id) { streams_.resize(dev_id + 1, nullptr); } @@ -78,7 +64,7 @@ class NaiveEngine final : public Engine { streams_[dev_id] = mshadow::NewStream(true, MXNET_USE_CUDNN != 0); } ctx_.stream = streams_[dev_id]; - exec_fun(ctx_, callback); + MSHADOW_CATCH_ERROR(exec_fun(ctx_, callback)); #else LOG(FATAL) << "GPU is not enabled"; #endif diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index dbaf9bb8dce6..313db6d2010b 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -120,14 +120,7 @@ void StreamManager::Finalize() { if (gpu_cnt_.at(i) != -1) { for (auto&& j : gpu_streams_.at(i)) { // Catch exception for CUDA driver shutdown - try { - mshadow::DeleteStream(j); - } catch (const dmlc::Error &e) { - std::string what = e.what(); - if (what.find("driver shutting down") == std::string::npos) { - LOG(ERROR) << "Ignore Error " << what << " during worker finalization"; - } - } + MSHADOW_CATCH_ERROR(mshadow::DeleteStream(j)); } gpu_cnt_.at(i) = -1; } diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index fa29939d291f..b8fad8e61a62 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -246,7 +246,7 @@ class ThreadedEngine : public Engine { ThreadedOpr* threaded_opr = opr_block->opr; CallbackOnComplete callback = this->CreateCallback( ThreadedEngine::OnCompleteStatic, threaded_opr); - threaded_opr->fn(run_ctx, callback); + MSHADOW_CATCH_ERROR(threaded_opr->fn(run_ctx, callback)); OprBlock::Delete(opr_block); } diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index c58352def413..c2848e36d831 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -48,14 +48,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) { if (ctx.dev_mask() == gpu::kDevMask) { #if MXNET_USE_CUDA - try { - mshadow::SetDevice(ctx.dev_id); - } catch (const dmlc::Error &e) { - std::string what = e.what(); - if (what.find("driver shutting down") == std::string::npos) { - LOG(ERROR) << "Ignore Error " << what << " during worker finalization"; - } - } + MSHADOW_CATCH_ERROR(mshadow::SetDevice(ctx.dev_id)); #endif } RunContext run_ctx; @@ -147,14 +140,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { this->ExecuteOprBlock(run_ctx, opr_block); } // Catch exception for CUDA driver shutdown - try { - mshadow::DeleteStream(stream); - } catch (const dmlc::Error &e) { - std::string what = e.what(); - if (what.find("driver shutting down") == std::string::npos) { - LOG(ERROR) << "Ignore Error " << what << " during worker finalization"; - } - } + MSHADOW_CATCH_ERROR(mshadow::DeleteStream(stream)); #endif } /*! diff --git a/src/resource.cc b/src/resource.cc index 27a06fd358b5..1bfea3940953 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -112,14 +112,7 @@ class ResourceManagerImpl : public ResourceManager { mshadow::Random *r = prnd; Engine::Get()->DeleteVariable( [r](RunContext rctx) { - try { - delete r; - } catch (const dmlc::Error &e) { - std::string what = e.what(); - if (what.find("driver shutting down") == std::string::npos) { - LOG(ERROR) << "Ignore Error " << what << " resource finalization"; - } - } + MSHADOW_CATCH_ERROR(delete r); }, ctx, resource.var); } // set seed to a PRNG @@ -160,14 +153,7 @@ class ResourceManagerImpl : public ResourceManager { mshadow::TensorContainer* r = space[i]; Engine::Get()->DeleteVariable( [r](RunContext rctx){ - try { - r->Release(); - } catch (const dmlc::Error &e) { - std::string what = e.what(); - if (what.find("driver shutting down") == std::string::npos) { - LOG(ERROR) << "Ignore Error " << what << " resource finalization"; - } - } + MSHADOW_CATCH_ERROR(r->Release()); }, ctx, resource[i].var); } } From 0a3438528ef50750c1f5095b7836c86a09a40811 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 22:03:12 -0700 Subject: [PATCH 19/20] [SYMBOL] BUGFIX in DFSVisit --- src/symbol/symbol.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index fb2377dbb6b2..3515f8918bad 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -59,22 +59,24 @@ inline void Symbol::DFSVisit(FVisit fvisit) const { std::unordered_set visited; // put the head into the graph for (auto &head : heads_) { - Node *ptr = head.source.get(); + Node* ptr = head.source.get(); if (visited.count(ptr) == 0) { stack.push_back(std::make_pair(&head.source, 0)); + visited.insert(ptr); } } while (!stack.empty()) { std::pair *, uint32_t>& back = stack.back(); if (back.second == back.first->get()->inputs.size()) { fvisit(*(back.first)); - visited.insert(back.first->get()); stack.pop_back(); } else { std::vector& inputs = back.first->get()->inputs; Symbol::DataEntry& input = inputs.at(back.second++); - if (visited.count(input.source.get()) == 0) { + Node* ptr = input.source.get(); + if (visited.count(ptr) == 0) { stack.push_back(std::make_pair(&input.source, 0)); + visited.insert(ptr); } } } From 289a4f14fc8dc057447334c8c63d796ae72fc422 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 23:15:10 -0700 Subject: [PATCH 20/20] [SYMBOL] Add get internals --- include/mxnet/c_api.h | 19 +++++++++++++++++ include/mxnet/symbolic.h | 6 ++++++ python/mxnet/symbol.py | 21 +++++++++++++++++++ src/c_api.cc | 19 +++++++++++++++++ src/symbol/symbol.cc | 31 ++++++++++++++++++++++++++-- tests/python/unittest/test_symbol.py | 15 ++++++++++++++ 6 files changed, 109 insertions(+), 2 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 972081379dc0..f6bf4e5ad862 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -435,6 +435,25 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); +/*! + * \brief Get a symbol that contains all the internals. + * \param symbol The symbol + * \param out The output symbol whose outputs are all the internals. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol, + SymbolHandle *out); +/*! + * \brief Get index-th outputs of the symbol. + * \param symbol The symbol + * \param index the Index of the output. + * \param out The output symbol whose outputs are all the internals. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol, + mx_uint index, + SymbolHandle *out); + /*! * \brief List auxiliary states in the symbol. * \param symbol the symbol diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index edf3e30cab2b..c3f6d05dbb9a 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -101,6 +101,12 @@ class Symbol { */ Symbol operator () (const std::unordered_map& kwargs, const std::string& name) const; + /* + * \brief Get all the internal nodes of the symbol. + * \return symbol A new symbol whose output contains all the outputs of the symbols + * Including input variables and intermediate outputs. + */ + Symbol GetInternals() const; /*! * \brief get the gradient graph * \param wrt with respect to the input diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 44318c66200c..e4f795ec1795 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -153,6 +153,27 @@ def _compose(self, *args, **kwargs): check_call(_LIB.MXSymbolCompose( self.handle, name, num_args, keys, args)) + def __getitem__(self, index): + if not isinstance(index, int): + raise TypeError('Symbol only support integer index to fetch i-th output') + handle = SymbolHandle() + check_call(_LIB.MXSymbolGetOutput( + self.handle, mx_uint(index), ctypes.byref(handle))) + return Symbol(handle=handle) + + def get_internals(self): + """Get a new grouped symbol whose output contains all the internal outputs of this symbol. + + Returns + ------- + sgroup : Symbol + The internal of the symbol. + """ + handle = SymbolHandle() + check_call(_LIB.MXSymbolGetInternals( + self.handle, ctypes.byref(handle))) + return Symbol(handle=handle) + def list_arguments(self): """List all the arguments in the symbol. diff --git a/src/c_api.cc b/src/c_api.cc index d8d3b34ea386..5787ac877f4d 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -509,6 +509,25 @@ int MXSymbolCreateGroup(mx_uint num_symbols, API_END_HANDLE_ERROR(delete s); } +int MXSymbolGetOutput(SymbolHandle symbol, + mx_uint index, + SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + *s = (*static_cast(symbol))[index]; + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int MXSymbolGetInternals(SymbolHandle symbol, + SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + *s = static_cast(symbol)->GetInternals(); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) { Symbol *s = new Symbol(); API_BEGIN(); diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 3515f8918bad..ee7e19c1cad4 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -46,6 +46,10 @@ struct Symbol::Node { inline bool is_variable() const { return op == nullptr && !backward_source_node; } + /*! \return Whether it is backward op */ + inline bool is_backward() const { + return backward_source_node.get() != nullptr; + } }; /*! \return whwther the symbol is atomic */ @@ -202,9 +206,13 @@ std::vector Symbol::ListOutputs() const { if (head.source->is_variable()) { ret.push_back(head.source->name); } else { - // TODO(bing) rethink about output naming auto &hname = head.source->name; - std::string rname = head.source->op->ListOutputs()[head.index]; + std::string rname; + if (head.source->is_backward()) { + rname = head.source->backward_source_node->op->ListArguments()[head.index]; + } else { + rname = head.source->op->ListOutputs()[head.index]; + } if (hname.length() == 0) { ret.push_back(std::move(rname)); } else { @@ -248,6 +256,25 @@ Symbol Symbol::operator[] (size_t index) const { } } +Symbol Symbol::GetInternals() const { + Symbol ret; + this->DFSVisit([&ret](const std::shared_ptr &node) { + Node* n = node.get(); + uint32_t nout; + if (n->is_variable()) { + nout = 1; + } else if (n->is_backward()) { + nout = static_cast(n->backward_source_node->inputs.size()); + } else { + nout = n->op->NumVisibleOutputs(); + } + for (uint32_t i = 0; i < nout; ++i) { + ret.heads_.push_back(DataEntry(node, i)); + } + }); + return ret; +} + // create a default variable name inline std::string DefaultVarName(const std::string &op_name, const std::string &arg_name) { diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 4f7f7eb1109f..199d3dfaf7cb 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -30,6 +30,20 @@ def test_symbol_compose(): assert len(multi_out.list_outputs()) == 2 + +def test_symbol_internal(): + data = mx.symbol.Variable('data') + oldfc = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10) + net1 = mx.symbol.FullyConnected(data=oldfc, name='fc2', num_hidden=100) + net1.list_arguments() == ['data', + 'fc1_weight', 'fc1_bias', + 'fc2_weight', 'fc2_bias'] + internal = net1.get_internals() + nmap = {x: i for i, x in enumerate(internal.list_outputs())} + fc1 = internal[nmap['fc1_output']] + assert fc1.list_arguments() == oldfc.list_arguments() + + def test_symbol_pickle(): mlist = [models.mlp2(), models.conv()] data = pkl.dumps(mlist) @@ -50,6 +64,7 @@ def test_symbol_saveload(): if __name__ == '__main__': + test_symbol_internal() test_symbol_basic() test_symbol_compose() test_symbol_saveload()