diff --git a/Jenkinsfile b/Jenkinsfile index df39672c5ed2..e41cb7217de4 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -215,9 +215,9 @@ del /Q *.7z // Python unittest for CPU def python_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/unittest" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest" - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/train" } } @@ -225,7 +225,7 @@ def python_ut(docker_type) { // both CPU and GPU def python_gpu_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/gpu" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu" } } diff --git a/benchmark/python/sparse_op.py b/benchmark/python/sparse_op.py new file mode 100644 index 000000000000..15ca4df1be73 --- /dev/null +++ b/benchmark/python/sparse_op.py @@ -0,0 +1,228 @@ +import ctypes + +from mxnet.test_utils import * +import scipy.sparse as sp +import os +import time +import argparse + +from mxnet.base import check_call, _LIB +from util import get_data, estimate_density + +parser = argparse.ArgumentParser(description="Benchmark sparse operators", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet') +args = parser.parse_args() + +# some data information +kdda = { + 'data_mini': 'kdda.t.mini', + 'data_name': 'kdda.t', + 'data_origin_name': 'kdda.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2", + 'feature_dim': 20216830, + 'm': 200, + 'batch_size': [64] +} + +avazu = { + 'data_mini': 'avazu-app.t.mini', + 'data_name': 'avazu-app.t', + 'data_origin_name': 'avazu-app.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2", + 'feature_dim': 1000000, + 'm': 500, + 'batch_size': [64, 128] +} + + +def measure_cost(repeat, f, *args, **kwargs): + # start bench + start = time.time() + results = [] + for i in range(repeat): + results.append(f(*args, **kwargs)) + for result in results: + result.wait_to_read() + end = time.time() + diff = end - start + return diff / repeat + + +def test_dot_real(data_dict): + def get_iter(path, data_shape, batch_size): + data_train = mx.io.LibSVMIter(data_libsvm=path, + data_shape=data_shape, + batch_size=batch_size) + data_iter = iter(data_train) + return data_iter + + data_dir = os.path.join(os.getcwd(), 'data') + + path = os.path.join(data_dir, data_dict['data_name']) + if not os.path.exists(path): + get_data( + data_dir, + data_dict['data_name'], + data_dict['url'], + data_dict['data_origin_name'] + ) + assert os.path.exists(path) + + k = data_dict['feature_dim'] + m = data_dict['m'] + density = estimate_density(path, data_dict['feature_dim']) + + mini_path = os.path.join(data_dir, data_dict['data_mini']) + if not os.path.exists(mini_path): + os.system("head -n 2000 %r > %r" % 
(path, mini_path)) + assert os.path.exists(mini_path) + + print "Running Benchmarking on %r data" % data_dict['data_mini'] + for batch_size in data_dict['batch_size']: # iterator through different batch size of choice + print "batch_size is %d" % batch_size + # model + data_shape = (k, ) + train_iter = get_iter(mini_path, data_shape, batch_size) + weight = mx.nd.random_uniform(low=0, high=1, shape=(k, m)) + + csr_data = [] + dns_data = [] + num_batch = 0 + for batch in train_iter: + data = train_iter.getdata() + csr_data.append(data) + dns_data.append(data.todense()) + num_batch += 1 + bag_of_data = [csr_data, dns_data] + num_repeat = 5 + costs = [] + for d in bag_of_data: + weight.wait_to_read() + cost = 0. + count = 0 + for d_batch in d: + d_batch.wait_to_read() + cost += measure_cost(num_repeat, mx.nd.dot, d_batch, weight) + count += 1 + costs.append(cost/count) + t_sparse = costs[0] + t_dense = costs[1] + ratio = t_dense / t_sparse + print('density(%)\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse') + fmt = "%0.4f\t\t%d\t%d\t%d\t%0.2f\t\t\t%0.4f\t%0.6f" + print(fmt % (density * 100, batch_size, m, k, ratio, t_dense, t_sparse)) + + +def test_dot_synthetic(): + """benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density. + `t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost + of dot(dns, dns), with the same matrix except that it is in default storage type. + """ + def measure_cost_forward_baseline(repeat, dot, lhs, rhs): + start = time.time() + for i in range(repeat): + dot(lhs, rhs) + end = time.time() + diff = end - start + return diff / repeat + + def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs): + start = time.time() + for i in range(repeat): + dot(transpose(lhs), rhs) + end = time.time() + diff = end - start + return diff / repeat + + def bench_dot_forward(m, k, n, density, ctx, repeat): + set_default_context(ctx) + dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx) + data_shape = (m, k) + csr_data = rand_ndarray(data_shape, 'csr', density) + dns_data = csr_data.todense() + rhs_dns_np = dns.asnumpy() + lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) # csr in scipy + lhs_dns_np = lhs_csr_sp.todense() + + data = [dns_data, csr_data] + costs = [] + for d in data: + dns.wait_to_read() + d.wait_to_read() + cost = measure_cost(repeat, mx.nd.dot, d, dns) + costs.append(cost) + ratio = costs[0] / costs[1] + + costs_baseline = [] + cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np) + costs_baseline.append(cost) + cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np) + costs_baseline.append(cost) + ratio_baseline = costs_baseline[0] / costs_baseline[1] + fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f" + print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1], + ratio_baseline, costs_baseline[0], costs_baseline[1])) + + def bench_dot_backward(m, k, n, density, ctx, repeat): + set_default_context(ctx) + dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx) + data_shape = (m, k) + csr_data = rand_ndarray(data_shape, 'csr', density) + dns_data = csr_data.todense() + rhs_dns_np = dns.asnumpy() + lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) + lhs_dns_np = lhs_csr_sp.todense() + + data = [dns_data, csr_data] + costs = [] + for d in data: + dns.wait_to_read() + d.wait_to_read() + cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True) + costs.append(cost) + ratio = costs[0] / costs[1] + + costs_baseline = [] + cost = 
measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np) + costs_baseline.append(cost) + cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np) + costs_baseline.append(cost) + ratio_baseline = costs_baseline[0] / costs_baseline[1] + fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f" + print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1], + ratio_baseline, costs_baseline[0], costs_baseline[1])) + + print("A = sparse NDArray of shape(m, k)") + print("B = dense NDArray of shape(k, n)") + print("dot_forward\tdot(csr, dns)") + print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse' + '\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse') + + check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads))) + # TODO(haibin) make these runtime options + m = 512 + k = [50000, 100000] + n = [64, 128] + density = [1.00, 0.90, 0.70, 0.50, 0.30, 0.20, 0.10, 0.07, 0.05, 0.02, 0.01, 0.005, 0.001] + num_repeat = 10 + # contexts = [mx.cpu(), mx.gpu(0)] + contexts = [mx.cpu()] + for i in range(2): + for ctx in contexts: + for den in density: + bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat) + + print("dot_backward\tdot(csr.T, dns)") + print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse' + '\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse') + for i in range(2): + for ctx in contexts: + for den in density: + bench_dot_backward(m, k[i], n[i], den, ctx, num_repeat) + + +if __name__ == "__main__": + test_dot_real(avazu) + test_dot_real(kdda) + test_dot_synthetic() diff --git a/benchmark/python/util.py b/benchmark/python/util.py new file mode 100644 index 000000000000..86e67d0f8a20 --- /dev/null +++ b/benchmark/python/util.py @@ -0,0 +1,33 @@ +import os +import random + + +def get_data(data_dir, data_name, url, data_origin_name): + if not os.path.isdir(data_dir): + os.system("mkdir " + data_dir) + os.chdir(data_dir) + if (not os.path.exists(data_name)): + import urllib + zippath = os.path.join(data_dir, data_origin_name) + urllib.urlretrieve(url, zippath) + os.system("bzip2 -d %r" % data_origin_name) + os.chdir("..") + + +def estimate_density(DATA_PATH, feature_size): + """sample 10 times of a size of 1000 for estimating the density of the sparse dataset""" + if not os.path.exists(DATA_PATH): + raise Exception("Data is not there!") + density = [] + P = 0.01 + for _ in xrange(10): + num_non_zero = 0 + num_sample = 0 + with open(DATA_PATH) as f: + for line in f: + if (random.random() < P): + num_non_zero += len(line.split(" ")) - 1 + num_sample += 1 + density.append(num_non_zero * 1.0 / (feature_size * num_sample)) + return sum(density) / len(density) + diff --git a/dmlc-core b/dmlc-core index a6c5701219e6..3f919c0d850c 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit a6c5701219e635fea808d264aefc5b03c3aec314 +Subproject commit 3f919c0d850cab959aada246dcf305c9b6ab5a7d diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 90270f776456..b1ae3e70bb70 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -246,6 +246,38 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape, int delay_alloc, int dtype, NDArrayHandle *out); + + +/*! 
+ * \brief create an empty sparse NDArray with specified shape and data type + * \param storage_type the storage type of the ndarray + * \param shape the pointer to the shape + * \param ndim the dimension of the shape + * \param dev_type device type, specify device we want to take + * \param dev_id the device id of the specific device + * \param delay_alloc whether to delay allocation until + * the narray is first mutated + * \param dtype data type of created array + * \param num_aux the number of aux data to support this ndarray + * \param aux_type data type of the aux data for the created array + * \param aux_ndims the dimension of the shapes of aux data + * \param aux_shape the shapes of aux data + * \param out the returning handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out); + /*! * \brief create a NDArray handle that is loaded from raw bytes. * \param buf the head of the raw bytes @@ -358,6 +390,7 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_begin, mx_uint slice_end, NDArrayHandle *out); + /*! * \brief Index the NDArray along axis 0. * \param handle the handle to the NDArray @@ -368,6 +401,13 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, MXNET_DLL int MXNDArrayAt(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out); + +/*! + * \brief get the storage type of the array + */ +MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle, + int *out_storage_type); + /*! * \brief Reshape the NDArray. * \param handle the handle to the narray @@ -406,6 +446,26 @@ MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle, */ MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle, int *out_dtype); + +/*! + * \brief get the type of the ith aux data in NDArray + * \param handle the handle to the narray + * \param i the index of the aux data + * \param out_type pointer holder to get type of aux data + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle, + mx_uint i, + int *out_type); + +// Get the ith aux data blob wrapped in an NDArray +MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle, + mx_uint i, + NDArrayHandle *out); + +// Get the data blob wrapped in an NDArray +MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle, + NDArrayHandle *out); /*! * \brief get the context of the NDArray * \param handle the handle to the narray @@ -1003,6 +1063,25 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, mx_uint *aux_type_size, const int **aux_type_data, int *complete); + + + + +/*! + * \brief infer storage type of unknown input types given the known one. 
+ */ +MXNET_DLL int MXSymbolInferStorageType(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int *arg_storage_type_data, + mx_uint *in_storage_type_size, + const int **in_storage_type_data, + mx_uint *out_storage_type_size, + const int **out_storage_type_data, + mx_uint *aux_storage_type_size, + const int **aux_storage_type_data, + int *complete); + //-------------------------------------------- // Part 4: Executor interface //-------------------------------------------- @@ -1167,6 +1246,9 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, const mx_uint num_provided_arg_dtypes, const char** provided_arg_dtype_names, const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, const mx_uint num_shared_arg_names, const char** shared_arg_name_list, int* shared_buffer_len, @@ -1328,6 +1410,19 @@ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, const int* keys, NDArrayHandle* vals); +/*! + * \brief Init a list of (key,value) pairs in kvstore, where each key is a string + * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals); + /*! * \brief Push a list of (key,value) pairs to kvstore * \param handle handle to the kvstore @@ -1342,6 +1437,20 @@ MXNET_DLL int MXKVStorePush(KVStoreHandle handle, const int* keys, NDArrayHandle* vals, int priority); +/*! + * \brief Push a list of (key,value) pairs to kvstore, where each key is a string + * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \param priority the priority of the action + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority); /*! * \brief pull a list of (key, value) pairs from the kvstore * \param handle handle to the kvstore @@ -1356,6 +1465,20 @@ MXNET_DLL int MXKVStorePull(KVStoreHandle handle, const int* keys, NDArrayHandle* vals, int priority); +/*! + * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string + * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \param priority the priority of the action + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority); /*! 
* \brief user-defined updater for the kvstore * It's this updater's responsibility to delete \a recv and \a local diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 40bd60f5f405..5856b87cf859 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -115,6 +115,7 @@ class Executor { const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& param_names, std::vector* in_args, diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index dafaf1bf9cab..a77f653d492c 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -63,6 +63,13 @@ class KVStore { */ virtual void Init(const std::vector& keys, const std::vector& values) = 0; + /*! + * \brief Initialize a list of key-value pair to the store. + * \param keys a list of unique keys in string format + * \param values a list of values + */ + virtual void Init(const std::vector& str_keys, + const std::vector& values) = 0; /*! * \brief push a list of key-value pairs into the store * @@ -102,6 +109,16 @@ class KVStore { virtual void Push(const std::vector& keys, const std::vector& values, int priority = 0) = 0; + + /*! + * \brief push a list of key-value pairs into the store + * \param keys the list of keys in string format + * \param values the list of values + * \param priority Priority of the action. + */ + virtual void Push(const std::vector& str_keys, + const std::vector& values, + int priority = 0) = 0; /*! * \brief pull a list of key-value pairs from the store * @@ -128,6 +145,16 @@ class KVStore { virtual void Pull(const std::vector& keys, const std::vector& values, int priority = 0) = 0; + /*! + * \brief pull a list of key-value pairs from the store + * \param keys the list of keys in string format + * \param values the list of buffers for the pulled data, they should be preallocated + * \param priority Priority of the action. + */ + virtual void Pull(const std::vector& str_keys, + const std::vector& values, + int priority = 0) = 0; + /** * \brief the prototype of user-defined updater diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 504fd5e7676e..e1e6269e3d8b 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -29,7 +29,11 @@ namespace mxnet { -// forward declaration +namespace ndarray { +template +void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx); +}; + namespace autograd { class AGNode; @@ -53,6 +57,23 @@ class AGNodeEntry { class AutogradRuntime; } // namespace autograd +// enum for storage types +namespace csr { +enum CSRAuxType {kIndPtr, kIdx}; +} + +namespace rowsparse { +enum RowSparseAuxType {kIdx}; +} + +enum NDArrayStorageType { + kUndefinedStorage = -1, // undefined storage + kDefaultStorage, // dense + kRowSparseStorage, // row sparse + kCSRStorage, // csr +}; + + /*! * \brief ndarray interface */ @@ -73,10 +94,55 @@ class NDArray { */ NDArray(const TShape &shape, Context ctx, bool delay_alloc = false, int dtype = mshadow::default_type_flag) - : ptr_(std::make_shared(shape.Size(), ctx, delay_alloc, dtype)), + : ptr_(std::make_shared(shape, ctx, delay_alloc, dtype)), shape_(shape), dtype_(dtype), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); +#endif + } + /*! 
\brief constructor for NDArray with storage type + */ + NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx, + bool delay_alloc = true, int dtype = mshadow::default_type_flag, + std::vector aux_types = {}, std::vector aux_shapes = {}, + TShape storage_shape = TShape(mshadow::Shape1(0))) + : shape_(shape), dtype_(dtype), entry_({nullptr, 0, 0}) { + // Assign default aux types if not given + if (aux_types.size() == 0) { + if (stype == kRowSparseStorage) { + aux_types = {mshadow::kInt64}; + } else if (stype == kCSRStorage) { + aux_types = {mshadow::kInt64, mshadow::kInt64}; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + // Assign default shapes if not given + // unknown shapes are intialized as {0} such that Size() would return 0 + if (aux_shapes.size() == 0) { + if (stype == kRowSparseStorage) { + aux_shapes = {TShape(mshadow::Shape1(0))}; + } else if (stype == kCSRStorage) { + // aux shapes for indptr and indices + aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))}; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + if (storage_shape.Size() == 0) { + if (stype == kRowSparseStorage) { + storage_shape = shape; + storage_shape[0] = aux_shapes[rowsparse::kIdx][0]; + } else if (stype == kCSRStorage) { + storage_shape = aux_shapes[csr::kIdx]; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + ptr_ = std::make_shared(stype, storage_shape, ctx, delay_alloc, + dtype, aux_types, aux_shapes); +#if MKL_EXPERIMENTAL == 1 + Mkl_mem_ = std::make_shared(); #endif } /*! @@ -85,28 +151,116 @@ class NDArray { * make sure the memory region is available through out the life of NDArray * \param data the memory content of static data * \param dev_id the device id this tensor sits at + * \param shared_var the same var handle shared with others. + It will not be deleted during destruction. */ - NDArray(const TBlob &data, int dev_id) - : ptr_(std::make_shared(data, dev_id)), shape_(data.shape_), + NDArray(const TBlob &data, int dev_id, Engine::VarHandle shared_var = nullptr) + : ptr_(std::make_shared(data, dev_id, shared_var)), shape_(data.shape_), dtype_(data.type_flag_), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); #endif } + + /*! + * \brief constructing a static NDArray of non-default storage that shares data with TBlob + * Use with caution: allocate ONLY ONE NDArray for each TBlob, + * make sure the memory region is available through out the life of NDArray + * \param stype the storage type of NDArray + * \param shape the shape of NDArray + * \param data the memory content of static data + * \param aux_data the memory content of static aux data + * \param dev_id the device id this tensor sits at + * \param shared_var the same var handle shared with others. + It will not be deleted during destruction. + */ + NDArray(const NDArrayStorageType stype, const TShape &shape, + const TBlob &data, const std::vector &aux_data, int dev_id) + : ptr_(std::make_shared(stype, data, aux_data, dev_id)), shape_(shape), + dtype_(data.type_flag_), entry_({nullptr, 0, 0}) { +#if MKL_EXPERIMENTAL == 1 + Mkl_mem_ = std::make_shared(); +#endif + } + + /*! - * \return the shape of current NDArray + * \return the shape of current NDArray. */ inline const TShape& shape() const { return shape_; } + /*! + * \return the shape of underlying chunk which stores the NDArray values. + * For default storage, it is the same as shape(). 
For row-sparse storage, it is the shape of + * the tensor which stores the non-zero values. + */ + inline const TShape &storage_shape() const { + CHECK(ptr_ != nullptr); + return ptr_->storage_shape; + } + + /*! + * \brief For sparse operations, the storage shape is an estimated value + * in the beginning for allocating enough capacity for the final result. + * After the operation is done, the exact size of the shape is known + * and need to be reset using this function. For example, adding + * two CSRs with nnz1 and nnz2 as their numbers of non-zero values, respectively, + * would allocate the array of size nnz1+nnz2 first and get the final + * nnz that is smaller than nnz1+nnz2. Therefore, the storage shape's size + * needs to be shrunk from nnz1+nnz2 to nnz. + */ + inline void set_storage_shape(const TShape& sshape) { + CHECK(storage_type() != kDefaultStorage); + ptr_->storage_shape = sshape; + } + + /*! + * \return the shape of aux data at ith index. If it doesn't exist, return an empty one. + */ + inline const TShape aux_shape(size_t i) const { + CHECK(storage_type() != kDefaultStorage); + return ptr_->aux_shapes[i]; + } + + /*! + * \brief For a sparse operation on a csr matrix for example, + * the size of the column index array + * is an estimated value in the beginning for allocating enough capacity + * for the final result. After the operation is done, the exact size of + * the shape is known and need to be reset using this function. + */ + inline void set_aux_shape(size_t i, const TShape& shape) const { + ptr_->aux_shapes[i] = shape; + } + /*! * \return the data TBlob */ inline const TBlob& data() const { - CheckAndAlloc(); + if (storage_type() == kDefaultStorage) CheckAndAlloc(); SetTBlob(); return tblob_; } + /*! + * \return the aux TBlob + */ + inline TBlob aux_data(size_t i) const { + auto stype = storage_type(); + TBlob res; + auto shape = aux_shape(i); + auto type = aux_type(i); + MSHADOW_TYPE_SWITCH(type, DType, { + auto dptr = static_cast(ptr_->aux_handles[i].dptr); + CHECK(stype == kRowSparseStorage || stype == kCSRStorage) + << "Unexpected storage type: " << stype; + res = TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type); + }); +#if MKL_EXPERIMENTAL == 1 + res.Mkl_mem_ = Mkl_mem_; +#endif + return res; + } /*! * \return the context of NDArray, this function is only valid when the NDArray is not empty */ @@ -119,6 +273,15 @@ class NDArray { inline int dtype() const { return dtype_; } + inline int aux_type(size_t i) const { + CHECK(!is_none()); + return ptr_->aux_types[i]; + } + + inline NDArrayStorageType storage_type() const { + if (is_none()) return kUndefinedStorage; + return ptr_->storage_type; + } /*! \return whether this ndarray is not initialized */ inline bool is_none() const { return ptr_.get() == nullptr; @@ -127,6 +290,18 @@ class NDArray { bool fresh_out_grad() const; /*! \return updated grad state in entry_ */ void set_fresh_out_grad(bool state) const; + // returns true if a sparse ndarray's aux_data and storage are initialized + inline bool storage_initialized() const { + if (is_none()) return false; + auto stype = storage_type(); + CHECK_NE(stype, kDefaultStorage); + if (stype == kRowSparseStorage || stype == kCSRStorage) { + return aux_shape(0).Size() != 0; + } else { + LOG(FATAL) << "Unknown storage type"; + } + return true; + } /*! * \brief Block until all the pending write operations with respect * to current NDArray are finished, and read can be performed. 
@@ -157,6 +332,12 @@ class NDArray { * \param strm the output stream */ void Save(dmlc::Stream *strm) const; + /*! + * \brief load ndarrays before supporting sparse ndarrays + * \param strm the output stream + * \param magic the magic number used for version control + */ + bool LegacyLoad(dmlc::Stream *strm, const uint32_t magic); /*! * \brief load the content from binary stream * \param strm the output stream @@ -260,17 +441,31 @@ class NDArray { void SyncCopyToCPU(void *data, size_t size) const; /*! * \brief Slice a NDArray - * \param begin begin index in first dim - * \param end end index in first dim + * \param begin begin index in first dim (inclusive) + * \param end end index in first dim (exclusive) * \return sliced NDArray */ NDArray Slice(index_t begin, index_t end) const; + /*! * \brief Index a NDArray * \param idx the index * \return idx-th sub array NDArray */ NDArray At(index_t idx) const; + // Wrap the tblob of aux data into an NDArray which shares the same variable with the + // current one. + inline const NDArray aux_ndarray(size_t i) const { + CHECK_NE(storage_type(), kDefaultStorage); + CHECK(i < ptr_->aux_shapes.size()); + return NDArray(aux_data(i), ctx().dev_id, var()); + } + // Wrap the tblob of data into an NDArray which shares the same variable with the + // current one. + inline const NDArray data_ndarray() const { + CHECK_NE(storage_type(), kDefaultStorage); + return NDArray(data(), ctx().dev_id, var()); + } /*! * \brief Create a NDArray that shares memory with current one * The new array must have smaller memory size than the current array. @@ -279,6 +474,7 @@ class NDArray { * \return NDArray in new shape and type. */ inline NDArray AsArray(const TShape &shape, int dtype) const { + CHECK_EQ(storage_type(), kDefaultStorage) << "Not implemented yet"; CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_), shape.Size() * mshadow::mshadow_sizeof(dtype)) << "NDArray.AsArray: target memory size is bigger"; @@ -312,8 +508,25 @@ class NDArray { * This is an internal function used by system that normal user should not use */ inline void CheckAndAlloc() const { + CHECK_EQ(storage_type(), kDefaultStorage); ptr_->CheckAndAlloc(); } + /* ! + * \brief Alloc memory for non-default storage + * aux_shape is only known at run time + */ + inline void CheckAndAlloc(const std::vector &aux_shapes) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_); + } + inline void CheckAndAllocData(const TShape &storage_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocData(storage_shape, dtype_); + } + inline void CheckAndAllocAuxData(size_t i, const TShape &aux_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocAuxData(i, aux_shape); + } /*! * \brief Save list of ndarray into the Stream.x * \param fo The stream of output. @@ -336,44 +549,132 @@ class NDArray { private: friend class autograd::AutogradRuntime; /*! \brief the real data chunk that backs NDArray */ + // shandle is used to store the actual values in the NDArray + // aux_handles store the aux data(such as indices) if it's needed by non-default storage. struct Chunk { - /*! \brief storage handlefrom storage engine */ + /*! \brief storage handle from storage engine. + for non-default storage, shandle stores the data(value) array. + */ Storage::Handle shandle; + /*! 
\brief storage handles for aux data (e.g index) + for row_sparse, aux_handles[0] = indices + for csr, aux_handles[0] = indptr, aux_handles[1] = indices + */ + std::vector aux_handles; /*! \brief variable from engine */ Engine::VarHandle var; /*! * \brief if this is true, this means the data do not come * from Storage, and do not need to be freed */ + /*! \brief construct from static data */ bool static_data; - /*! \brief whether allocation is delayed */ + /*! \brief whether data allocation is delayed. This doesn't indicate whether aux data + allocation is delayed. */ bool delay_alloc; + // the type of the storage. The storage_type is never kUndefinedStorage once the chunk + // is constructed. + NDArrayStorageType storage_type = kDefaultStorage; + /*! \brief type of aux */ + std::vector aux_types; + // context of data + Context ctx; + // The shape of the chunk data. + // This might not be the same shape as the NDArray, since the storage may be sparse. + // The default value for storage_shape is {0} when an empty non-default NDArray is created. + TShape storage_shape; + // The shape of aux data. The default value for the shape depends on the type of storage. + // If aux_shapes[i].Size() is zero, aux data i is empty. + std::vector aux_shapes; + // \brief skip the deletion of var handle. Usually set when shared_var is present. + bool skip_delete_var = false; + /*! \brief default cosntructor */ - Chunk() : static_data(true), delay_alloc(false) { - var = Engine::Get()->NewVariable(); - } - /*! \brief construct from static data */ - Chunk(const TBlob &data, int dev_id) - : static_data(true), - delay_alloc(false) { + Chunk() : static_data(true), delay_alloc(false) {} + + /*! \brief construct a new chunk */ + Chunk(TShape shape, Context ctx_, bool delay_alloc_, int dtype) + : static_data(false), delay_alloc(true), ctx(ctx_) { + auto size = shape.Size(); + storage_shape = shape; var = Engine::Get()->NewVariable(); + shandle.size = size * mshadow::mshadow_sizeof(dtype); + shandle.ctx = ctx_; + if (!delay_alloc_) this->CheckAndAlloc(); + } + + Chunk(const TBlob &data, int dev_id, Engine::VarHandle shared_var) + : static_data(true), delay_alloc(false) { + CHECK(storage_type == kDefaultStorage); + // init var + if (shared_var == nullptr) { + var = Engine::Get()->NewVariable(); + } else { + skip_delete_var = true; + var = shared_var; + } + // init ctx if (data.dev_mask() == cpu::kDevMask) { - shandle.ctx = Context::CPU(); + ctx = Context::CPU(); } else { CHECK_EQ(data.dev_mask(), gpu::kDevMask); - shandle.ctx = Context::GPU(dev_id); + ctx = Context::GPU(dev_id); } + // init shandle + shandle.ctx = ctx; shandle.dptr = data.dptr_; shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_); + storage_shape = data.shape_; } - /*! 
\brief construct a new chunk */ - Chunk(uint64_t size, Context ctx, bool delay_alloc_, int dtype) - : static_data(false), delay_alloc(true) { + // Constructor for a non-default storage chunk + Chunk(NDArrayStorageType storage_type_, const TShape &storage_shape_, Context ctx_, + bool delay_alloc_, int dtype, const std::vector &aux_types_, + const std::vector &aux_shapes_) + : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_), + aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_), + aux_shapes(aux_shapes_) { + shandle.ctx = ctx; var = Engine::Get()->NewVariable(); - shandle.size = size * mshadow::mshadow_sizeof(dtype); + // aux_handles always reflect the correct number of aux data + for (size_t i = 0; i < aux_shapes.size(); i++) { + CheckAndAllocAuxData(i, aux_shapes[i]); + } + if (!delay_alloc) { + CheckAndAllocData(storage_shape, dtype); + } + } + + Chunk(const NDArrayStorageType storage_type_, const TBlob &data, + const std::vector &aux_data, int dev_id) + : static_data(true), delay_alloc(false), storage_type(storage_type_) { + using namespace mshadow; + CHECK_NE(storage_type, kDefaultStorage); + // init var + var = Engine::Get()->NewVariable(); + // init ctx + if (data.dev_mask() == cpu::kDevMask) { + ctx = Context::CPU(); + } else { + CHECK_EQ(data.dev_mask(), gpu::kDevMask); + ctx = Context::GPU(dev_id); + } + // init shandle shandle.ctx = ctx; - if (!delay_alloc_) this->CheckAndAlloc(); + shandle.dptr = data.dptr_; + shandle.size = data.shape_.Size() * mshadow_sizeof(data.type_flag_); + storage_shape = data.shape_; + // init aux handles + for (const auto &aux : aux_data) { + Storage::Handle aux_handle; + aux_handle.ctx = ctx; + aux_handle.dptr = aux.dptr_; + aux_handle.size = aux.shape_.Size() * mshadow_sizeof(aux.type_flag_); + aux_handles.push_back(aux_handle); + aux_types.emplace_back(aux.type_flag_); + aux_shapes.emplace_back(aux.shape_); + } } + /*! \brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { if (delay_alloc) { @@ -381,22 +682,98 @@ class NDArray { delay_alloc = false; } } - /*! \brief destructor */ - ~Chunk() { - if (static_data || delay_alloc) { - Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var); + inline void CheckAndAlloc(const TShape &shape, const std::vector &aux_shapes, + int dtype) { + // calculate size, perform allocation + if (kRowSparseStorage == storage_type) { + // For row sparse, aux_shape indicates the number of rows to allocate + auto aux_shape = aux_shapes[rowsparse::kIdx]; + CHECK_EQ(shape.ndim(), 2) << "High dim RowSparse not yet implemented"; + CheckAndAllocAuxData(rowsparse::kIdx, aux_shape); + TShape storage_shape(shape); + storage_shape[0] = aux_shape[0]; + CheckAndAllocData(storage_shape, dtype); + } else if (kCSRStorage == storage_type) { + CheckAndAllocAuxData(csr::kIndPtr, aux_shapes[csr::kIndPtr]); + CheckAndAllocAuxData(csr::kIdx, aux_shapes[csr::kIdx]); + CheckAndAllocData(aux_shapes[csr::kIdx], dtype); } else { - Storage::Handle h = this->shandle; - Engine::Get()->DeleteVariable([h](RunContext s) { - Storage::Get()->Free(h); - }, shandle.ctx, var); + LOG(FATAL) << "Storage type " << storage_type << " not implemented for CheckAndAlloc"; } } - }; + // create storage handle for data based on shape and dtype, assuming ctx is set + // storage shape is also updated + // if data is already allocated, try reuse the storage. 
Otherwise, free the current one + // and allocate new storage + inline void CheckAndAllocData(const TShape &shape, int dtype) { + CHECK_NE(aux_shapes.size(), 0) << "data is expected to be allocated after aux_data"; + auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype); + if (shandle.size < dbytes) { + // free storage if necessary and alloc again + if (shandle.size > 0) Storage::Get()->Free(shandle); + // init storage + shandle = Storage::Get()->Alloc(dbytes, ctx); + } + // init shape + storage_shape = shape; + // delay_alloc is only set when data storage handle is present + delay_alloc = false; + } + // create storage handle for aux data based on shape + // this function assumes ctx, aux shapes and aux types are set + // aux shape is also updated + // if aux data is already allocated, try reuse the storage. Otherwise, free the current one + // and allocate new storage + inline void CheckAndAllocAuxData(size_t i, const TShape &shape) { + CHECK_EQ(shape.ndim(), 1) << "shape must be 1D in CheckAndAllocAuxData"; + CHECK_NE(storage_type, kUndefinedStorage) + << "storage type cannot be kUndefinedStorage in CheckAndAllocAuxData"; + CHECK_NE(storage_type, kDefaultStorage) + << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData"; + if (aux_handles.size() <= i) { + aux_handles.resize(i + 1); + } + size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]); + if (aux_handles[i].size < aux_bytes) { + // free storage if necessary and alloc again + if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]); + // init aux storage + aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx); + } + // init shape + aux_shapes[i] = shape; + } + /*! \brief destructor */ + ~Chunk() { + if (skip_delete_var) return; + bool skip_free = static_data || delay_alloc; + Storage::Handle h = this->shandle; + std::vector aux_h = this->aux_handles; + Engine::Get()->DeleteVariable([h, aux_h, skip_free](RunContext s) { + if (skip_free == false) { + Storage::Get()->Free(h); + for (size_t i = 0; i < aux_h.size(); i++) { + if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]); + } + } + }, shandle.ctx, var); + } + }; // struct Chunk void SetTBlob() const { - tblob_.dptr_ = static_cast(ptr_->shandle.dptr) + byte_offset_; - tblob_.shape_ = shape_; + CHECK(ptr_ != nullptr); + TShape shape = shape_; + char *dptr = static_cast(ptr_->shandle.dptr); + auto stype = storage_type(); + if (stype == kDefaultStorage) { + dptr += byte_offset_; + } else if (stype == kCSRStorage || stype == kRowSparseStorage) { + shape = storage_shape(); + } else { + LOG(FATAL) << "unknown storage type " << stype; + } + tblob_.dptr_ = dptr; + tblob_.shape_ = shape; tblob_.type_flag_ = dtype_; tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id); #if MKL_EXPERIMENTAL == 1 @@ -408,7 +785,7 @@ class NDArray { std::shared_ptr Mkl_mem_; #endif /*! \brief internal data of NDArray */ - std::shared_ptr ptr_; + std::shared_ptr ptr_{nullptr}; /*! \brief shape of current NDArray */ TShape shape_; /*! \brief byte offset in chunk */ @@ -425,7 +802,12 @@ class NDArray { * this situation. */ mutable TBlob tblob_; -}; +}; // class NDArray + +/*! + * \return the number of aux data used for given storage type + */ +size_t num_aux_data(NDArrayStorageType stype); /*! * \brief issue an copy operation from one NDArray to another @@ -435,12 +817,12 @@ class NDArray { * \param from the ndarray we want to copy data from * \param to the target ndarray * \param priority Priority of the action. 
+ * \param alloc_output whether to allocate memory for the output ndarray * \note The function name explicitly marks the order of from and to * due to different possible convention carried by copy function. */ void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0); - /*! * \brief Perform elementwise sum over each data from source, store result into out. * \param source the ndarray we want to sum diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 316a90fe0841..cffca441e4b0 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -7,7 +7,6 @@ #ifndef MXNET_OP_ATTR_TYPES_H_ #define MXNET_OP_ATTR_TYPES_H_ - #include #include @@ -61,6 +60,17 @@ using FCompute = std::function& inputs, const std::vector& req, const std::vector& outputs)>; +/*! + * \brief Resiger an NDArray compute function for simple stateless forward only operator + * + * \note Register under "FComputeEx" and "FComputeEx" + * Dispatched only when operators process non-default storage inputs or outputs + */ +using FComputeEx = std::function& inputs, + const std::vector& req, + const std::vector& outputs)>; } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 1b765233947d..e236a9cf313b 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -23,11 +23,11 @@ class Storage { /*! * \brief Pointer to the data. */ - void* dptr; + void* dptr{nullptr}; /*! * \brief Size of the storage. */ - size_t size; + size_t size{0}; /*! * \brief Context information about device and ID. */ diff --git a/mshadow b/mshadow index c037b06ddd81..8db65bd081c7 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit c037b06ddd810d39322cd056650f8b1f4763dd9d +Subproject commit 8db65bd081c7e243028ace93ef0acc9efc4383ba diff --git a/nnvm b/nnvm index 7796ac76ccea..2e3561500de9 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit 7796ac76ccea1fba31afc32056c83f6da38b6c57 +Subproject commit 2e3561500de99a0c173f3bc7b1a6c2b31435d6d9 diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index ff5f6cd6be7e..1e8c7731f3fb 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -8,6 +8,8 @@ from . import base from . import contrib from . import ndarray +from . import sparse_ndarray +from . import ndarray_utils from . import name # use mx.sym as short for symbol from . import symbol as sym @@ -18,6 +20,8 @@ from . import operator # use mx.nd as short for mx.ndarray from . import ndarray as nd +from . import sparse_ndarray as sparse_nd +from . import ndarray_utils as nd_utils # use mx.rnd as short for mx.random from . import random as rnd from . import random diff --git a/python/mxnet/contrib/autograd.py b/python/mxnet/contrib/autograd.py index e56361efdb1f..aa212c72fc9a 100644 --- a/python/mxnet/contrib/autograd.py +++ b/python/mxnet/contrib/autograd.py @@ -7,6 +7,7 @@ import functools from ..base import _LIB, check_call, string_types from ..base import mx_uint, NDArrayHandle, c_array +# pylint: disable= unused-import from ..ndarray import NDArray, zeros_like from ..symbol import _GRAD_REQ_MAP diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 6b9aab2de6f1..3991319ff13a 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -11,6 +11,7 @@ from .base import mx_uint, NDArrayHandle, ExecutorHandle from .base import check_call, c_array, py_str from .ndarray import NDArray +from .sparse_ndarray import _ndarray_cls from . 
import ndarray as nd # those functions are not used here, we just import them to keep backward compatibility @@ -90,7 +91,9 @@ def _get_outputs(self): handles = ctypes.POINTER(NDArrayHandle)() check_call(_LIB.MXExecutorOutputs(self.handle, ctypes.byref(out_size), ctypes.byref(handles))) - return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] + num_output = out_size.value + outputs = [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(num_output)] + return outputs def forward(self, is_train=False, **kwargs): """Calculate the outputs specified by the bound symbol. diff --git a/python/mxnet/io.py b/python/mxnet/io.py index ec3c25f54d30..b728f50838a8 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -13,6 +13,7 @@ from .base import mx_real_t from .base import check_call, build_param_doc as _build_param_doc from .ndarray import NDArray +from .sparse_ndarray import _ndarray_cls from .ndarray import array from .ndarray import concatenate @@ -752,12 +753,12 @@ def iter_next(self): def getdata(self): hdl = NDArrayHandle() check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl))) - return NDArray(hdl, False) + return _ndarray_cls(hdl, False) def getlabel(self): hdl = NDArrayHandle() check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) - return NDArray(hdl, False) + return _ndarray_cls(hdl, False) def getindex(self): index_size = ctypes.c_uint64(0) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index ab07421caffd..655f602856da 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -10,31 +10,39 @@ from .base import NDArrayHandle, KVStoreHandle from . import optimizer as opt -def _ctype_key_value(keys, vals): - """ - Returns ctype arrays for the key-value args. For internal use. - """ - if isinstance(keys, int): +def _ctype_str_key_value(keys, vals): + names = [] + if isinstance(keys, str): if isinstance(vals, NDArray): - return (c_array(ctypes.c_int, [keys]), + names.append(c_str(keys)) + return (c_array(ctypes.c_char_p, names), c_array(NDArrayHandle, [vals.handle])) else: for value in vals: assert(isinstance(value, NDArray)) - return (c_array(ctypes.c_int, [keys] * len(vals)), + return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), c_array(NDArrayHandle, [value.handle for value in vals])) else: assert(len(keys) == len(vals)) for k in keys: - assert(isinstance(k, int)) + assert(isinstance(k, str)) c_keys = [] c_vals = [] for key, val in zip(keys, vals): - c_key_i, c_val_i = _ctype_key_value(key, val) + c_key_i, c_val_i = _ctype_str_key_value(key, val) c_keys += c_key_i c_vals += c_val_i - return (c_array(ctypes.c_int, c_keys), c_array(NDArrayHandle, c_vals)) + return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) +def _cast_to_str_keys(keys): + if isinstance(keys, str): + return keys + if isinstance(keys, int): + return str(keys) + str_keys = [] + for key in keys: + str_keys.append(str(key) if isinstance(key, int) else key) + return str_keys def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" @@ -48,7 +56,7 @@ def updater_handle(key, lhs_handle, rhs_handle, _): class KVStore(object): """A key-value store for synchronization of values, over multiple devices.""" - def __init__(self, handle): + def __init__(self, handle, name2idx=None): """Initializes a new KVStore. 
Parameters @@ -58,6 +66,7 @@ def __init__(self, handle): """ assert isinstance(handle, KVStoreHandle) self.handle = handle + self.name2idx = name2idx if name2idx is not None else {} self._updater = None self._updater_func = None @@ -74,7 +83,7 @@ def init(self, key, value): Parameters ---------- - key : int or sequence of int + key : str or sequence of str The keys. value : NDArray or sequence of NDArray Values corresponding to the keys. @@ -84,20 +93,20 @@ def init(self, key, value): >>> # init a single key-value pair >>> shape = (2,3) >>> kv = mx.kv.create('local') - >>> kv.init(3, mx.nd.ones(shape)*2) + >>> kv.init('3', mx.nd.ones(shape)*2) >>> a = mx.nd.zeros(shape) - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]] >>> # init a list of key-value pairs - >>> keys = [5, 7, 9] + >>> keys = ['5', '7', '9'] >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ - ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStoreInit( - self.handle, mx_uint(len(ckeys)), ckeys, cvals)) + key = _cast_to_str_keys(key) + ckeys, cvals = _ctype_str_key_value(key, value) + check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals)) def push(self, key, value, priority=0): """ Pushes a single or a sequence of key-value pairs into the store. @@ -110,7 +119,7 @@ def push(self, key, value, priority=0): Parameters ---------- - key : int or list of int + key : str or list of str Keys. value : NDArray or list of NDArray or list of list of NDArray @@ -124,8 +133,8 @@ def push(self, key, value, priority=0): Examples -------- >>> # push a single key-value pair - >>> kv.push(3, mx.nd.ones(shape)*8) - >>> kv.pull(3, out=a) # pull out the value + >>> kv.push('3', mx.nd.ones(shape)*8) + >>> kv.pull('3', out=a) # pull out the value >>> print a.asnumpy() [[ 8. 8. 8.] [ 8. 8. 8.]] @@ -133,8 +142,8 @@ def push(self, key, value, priority=0): >>> # aggregate the value and the push >>> gpus = [mx.gpu(i) for i in range(4)] >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] - >>> kv.push(3, b) - >>> kv.pull(3, out=a) + >>> kv.push('3', b) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 4. 4. 4.] [ 4. 4. 4.]] @@ -156,11 +165,13 @@ def push(self, key, value, priority=0): [[ 4. 4. 4.] [ 4. 4. 4.]] """ - ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStorePush( + key = _cast_to_str_keys(key) + ckeys, cvals = _ctype_str_key_value(key, value) + check_call(_LIB.MXKVStorePushEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) + def pull(self, key, out=None, priority=0): """ Pulls a single value or a sequence of values from the store. @@ -190,21 +201,21 @@ def pull(self, key, out=None, priority=0): -------- >>> # pull a single key-value pair >>> a = mx.nd.zeros(shape) - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]] >>> # pull into multiple devices >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] - >>> kv.pull(3, out=b) + >>> kv.pull('3', out=b) >>> print b[1].asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]] >>> # pull a list of key-value pairs. >>> # On single device - >>> keys = [5, 7, 9] + >>> keys = ['5', '7', '9'] >>> b = [mx.nd.zeros(shape)]*len(keys) >>> kv.pull(keys, out=b) >>> print b[1].asnumpy() @@ -218,8 +229,9 @@ def pull(self, key, out=None, priority=0): [ 2. 2. 
2.]] """ assert(out is not None) - ckeys, cvals = _ctype_key_value(key, out) - check_call(_LIB.MXKVStorePull( + key = _cast_to_str_keys(key) + ckeys, cvals = _ctype_str_key_value(key, out) + check_call(_LIB.MXKVStorePullEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) @@ -348,13 +360,13 @@ def _set_updater(self, updater): ... print "update on key: %d" % key ... stored += input * 2 >>> kv._set_updater(update) - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 4. 4. 4.] [ 4. 4. 4.]] - >>> kv.push(3, mx.nd.ones(shape)) + >>> kv.push('3', mx.nd.ones(shape)) update on key: 3 - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 6. 6. 6.] [ 6. 6. 6.]] @@ -395,7 +407,7 @@ def _send_command_to_servers(self, head, body): check_call(_LIB.MXKVStoreSendCommmandToServers( self.handle, mx_uint(head), c_str(body))) -def create(name='local'): +def create(name='local', name2idx=None): """Creates a new KVStore. For single machine training, there are two commonly used types: @@ -435,4 +447,4 @@ def create(name='local'): handle = KVStoreHandle() check_call(_LIB.MXKVStoreCreate(c_str(name), ctypes.byref(handle))) - return KVStore(handle) + return KVStore(handle, name2idx=name2idx) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 189f301e91f7..c91ef5474601 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -37,7 +37,7 @@ 'eval_metric', 'locals']) -def _create_kvstore(kvstore, num_device, arg_params): +def _create_kvstore(kvstore, num_device, arg_params, name2idx=None): """Create kvstore This function select and create a proper kvstore if given the kvstore type. @@ -61,7 +61,7 @@ def _create_kvstore(kvstore, num_device, arg_params): # no need to use kv for single device and single machine kv = None else: - kv = kvs.create(kvstore) + kv = kvs.create(kvstore, name2idx=name2idx) if kvstore == 'local': # automatically select a proper local max_size = max(np.prod(param.shape) for param in @@ -80,38 +80,42 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore): """Initialize kvstore""" for idx, param_on_devs in enumerate(param_arrays): - kvstore.init(idx, arg_params[param_names[idx]]) + name = param_names[idx] + kvstore.init(name, arg_params[name]) if update_on_kvstore: - kvstore.pull(idx, param_on_devs, priority=-idx) + kvstore.pull(name, param_on_devs, priority=-idx) -def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore): +def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): """Perform update of param_arrays from grad_arrays on kvstore.""" for index, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue + name = param_names[index] # push gradient, priority is negative index - kvstore.push(index, grad_list, priority=-index) + kvstore.push(name, grad_list, priority=-index) # pull back the weights - kvstore.pull(index, arg_list, priority=-index) + kvstore.pull(name, arg_list, priority=-index) def _update_params(param_arrays, grad_arrays, updater, num_device, - kvstore=None): + kvstore=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" - for index, pair in enumerate(zip(param_arrays, grad_arrays)): + for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue + index = i if kvstore: + name = param_names[index] # push gradient, priority is negative index - kvstore.push(index, 
grad_list, priority=-index) + kvstore.push(name, grad_list, priority=-index) # pull back the sum gradients, to the same locations. - kvstore.pull(index, grad_list, priority=-index) + kvstore.pull(name, grad_list, priority=-index) for k, p in enumerate(zip(arg_list, grad_list)): # faked an index here, to make optimizer create diff # state for the same index but on diff devs, TODO(mli) - # use a better solution latter + # use a better solution later w, g = p updater(index*num_device+k, g, w) @@ -245,13 +249,14 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, if update_on_kvstore: _update_params_on_kvstore(executor_manager.param_arrays, executor_manager.grad_arrays, - kvstore) + kvstore, executor_manager.param_names) else: _update_params(executor_manager.param_arrays, executor_manager.grad_arrays, updater=updater, num_device=len(ctx), - kvstore=kvstore) + kvstore=kvstore, + param_names=executor_manager.param_names) if monitor is not None: monitor.toc_print() diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index fef5c507d7e8..8af84a307a82 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -7,6 +7,7 @@ import logging import warnings +import mxnet as mx from .. import context as ctx from .. import ndarray as nd from .. import optimizer as opt @@ -398,7 +399,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, else: assert self._arg_params is None and self._aux_params is None param_arrays = [ - nd.zeros(x[0].shape, dtype=x[0].dtype) + mx.nd.zeros(shape=x[0].shape, dtype=x[0].dtype, storage_type=x[0].storage_type) for x in self._exec_group.param_arrays ] self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)} @@ -412,7 +413,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, if shared_module is not None and shared_module.optimizer_initialized: self.borrow_optimizer(shared_module) - def reshape(self, data_shapes, label_shapes=None): """Reshapes the module for new input shapes. @@ -454,8 +454,12 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', if self._params_dirty: self._sync_params_from_devices() + name2idx = {} + for idx, name in enumerate(self._exec_group.param_names): + name2idx[name] = idx + (kvstore, update_on_kvstore) = \ - _create_kvstore(kvstore, len(self._context), self._arg_params) + _create_kvstore(kvstore, len(self._context), self._arg_params, name2idx=name2idx) batch_size = self._exec_group.batch_size if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type: @@ -572,13 +576,14 @@ def update(self): if self._update_on_kvstore: _update_params_on_kvstore(self._exec_group.param_arrays, self._exec_group.grad_arrays, - self._kvstore) + self._kvstore, self._exec_group.param_names) else: _update_params(self._exec_group.param_arrays, self._exec_group.grad_arrays, updater=self._updater, num_device=len(self._context), - kvstore=self._kvstore) + kvstore=self._kvstore, + param_names=self._exec_group.param_names) def get_outputs(self, merge_multi_context=True): """Gets outputs of the previous forward computation. 
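
Note (not part of the diff): the Python-side changes above in kvstore.py, model.py and module.py switch KVStore keys from integer indices to parameter-name strings (MXKVStoreInitEx/PushEx/PullEx). As a minimal, hedged usage sketch of that string-key API, assuming these changes are built in; the key name 'fc1_weight' is a hypothetical example, and the only calls used are the existing mx.kv.create / init / push / pull entry points:

    import mxnet as mx

    shape = (2, 3)
    kv = mx.kv.create('local')
    kv.init('fc1_weight', mx.nd.ones(shape))        # keys are now strings, not ints
    kv.push('fc1_weight', mx.nd.ones(shape) * 8)    # push an update under the same key
    out = mx.nd.zeros(shape)
    kv.pull('fc1_weight', out=out)                  # pull the current value back
    print(out.asnumpy())                            # [[ 8. 8. 8.], [ 8. 8. 8.]] per the docstring example above

This mirrors how _update_params_on_kvstore now pushes and pulls by param_names[index] instead of the bare integer index.
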
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 8900843f5937..133e30ec6397 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -4,6 +4,7 @@ """NDArray API of MXNet.""" from __future__ import absolute_import from __future__ import division + try: from __builtin__ import slice as py_slice except ImportError: @@ -17,9 +18,9 @@ import operator import numpy as np -from .base import _LIB, string_types, numeric_types -from .base import c_array, py_str, c_str, mx_real_t, _Null # pylint: disable=unused-import -from .base import mx_uint, NDArrayHandle, check_call, OpHandle +from .base import _LIB, numeric_types, OpHandle, c_str +from .base import c_array, py_str, mx_real_t, _Null # pylint: disable=unused-import +from .base import mx_uint, NDArrayHandle, check_call from .base import ctypes2buffer from .context import Context from . import _ndarray_internal as _internal @@ -52,17 +53,29 @@ np.float64 : 1, np.float16 : 2, np.uint8 : 3, - np.int32 : 4 + np.int32 : 4, + np.int64 : 6 } - _DTYPE_MX_TO_NP = { 0 : np.float32, 1 : np.float64, 2 : np.float16, 3 : np.uint8, - 4 : np.int32 + 4 : np.int32, + 6 : np.int64 +} +_STORAGE_TYPE_ID_TO_STR = { + -1 : 'undefined', + 0 : 'default', + 1 : 'row_sparse', + 2 : 'csr', +} +_STORAGE_TYPE_STR_TO_ID = { + 'undefined' : -1, + 'default' : 0, + 'row_sparse' : 1, + 'csr' : 2, } -# pylint: enable= no-member def _new_empty_handle(): """Returns a new empty handle. @@ -106,6 +119,11 @@ def waitall(): """ check_call(_LIB.MXNDArrayWaitAll()) +def _storage_type(handle): + storage_type = ctypes.c_int(0) + check_call(_LIB.MXNDArrayGetStorageType(handle, ctypes.byref(storage_type))) + return _STORAGE_TYPE_ID_TO_STR[storage_type.value] + class NDArray(NDArrayBase): """An array object representing a multidimensional, homogeneous array of fixed-size items. @@ -113,12 +131,16 @@ class NDArray(NDArrayBase): """ __slots__ = [] # pylint: disable= no-member, undefined-variable + def __repr__(self): """Returns a string representation of the array.""" shape_info = 'x'.join(['%d' % x for x in self.shape]) return '<%s %s @%s>' % (self.__class__.__name__, shape_info, self.context) + def __reduce__(self): + return NDArray, (None,), self.__getstate__() + def __add__(self, other): """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """ return add(self, other) @@ -629,7 +651,6 @@ def wait_to_read(self): """ check_call(_LIB.MXNDArrayWaitToRead(self.handle)) - @property def ndim(self): """Returns the number of dimensions of this array @@ -664,6 +685,7 @@ def shape(self): self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) return tuple(pdata[:ndim.value]) + @property def size(self): """Number of elements in the array. @@ -725,6 +747,10 @@ def dtype(self): self.handle, ctypes.byref(mx_dtype))) return _DTYPE_MX_TO_NP[mx_dtype.value] + @property + def storage_type(self): + return _storage_type(self.handle) + @property # pylint: disable= invalid-name, undefined-variable def T(self): @@ -949,6 +975,13 @@ def backward(self, out_grad=None, retain_graph=False): c_array(NDArrayHandle, ograd_handles), ctypes.c_int(retain_graph))) + def _to_csr(self): + # pylint: disable=undefined-variable + return cast_storage(self, storage_type='csr') + + def _to_rsp(self): + # pylint: disable=undefined-variable + return cast_storage(self, storage_type='row_sparse') def onehot_encode(indices, out): """One-hot encoding indices into matrix out. 
@@ -993,42 +1026,8 @@ def empty(shape, ctx=None, dtype=mx_real_t): ctx = Context.default_ctx return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype)) -def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs): - """Returns a new array filled with all zeros, with the given shape and type. - - Parameters - ---------- - shape : int or tuple of int - The shape of the empty array. - ctx : Context, optional - An optional device context (default is the current default context). - dtype : str or numpy.dtype, optional - An optional value type (default is `float32`). - out : NDArray, optional - The output NDArray (default is `None`). - - Returns - ------- - NDArray - A created array - Examples - -------- - >>> mx.nd.zeros(1).asnumpy() - array([ 0.], dtype=float32) - >>> mx.nd.zeros((1,2), mx.gpu(0)) - - >>> mx.nd.zeros((1,2), mx.gpu(0), 'float16').asnumpy() - array([[ 0., 0.]], dtype=float16) - """ - # pylint: disable= unused-argument - if ctx is None: - ctx = Context.default_ctx - # pylint: disable= no-member, protected-access - return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) - # pylint: enable= no-member, protected-access - -def ones(shape, ctx=None, dtype=mx_real_t, **kwargs): +def ones(shape, ctx=None, dtype=None, **kwargs): """Returns a new array filled with all ones, with the given shape and type. Parameters @@ -1060,6 +1059,7 @@ def ones(shape, ctx=None, dtype=mx_real_t, **kwargs): # pylint: disable= unused-argument if ctx is None: ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) # pylint: enable= no-member, protected-access @@ -2101,89 +2101,6 @@ def negative(arr): """ return multiply(arr, -1.0) -def load(fname): - """Loads an array from file. - - See more details in ``save``. - - Parameters - ---------- - fname : str - The filename. - - Returns - ------- - list of NDArray or dict of str to NDArray - Loaded data. - """ - if not isinstance(fname, string_types): - raise TypeError('fname required to be a string') - out_size = mx_uint() - out_name_size = mx_uint() - handles = ctypes.POINTER(NDArrayHandle)() - names = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.MXNDArrayLoad(c_str(fname), - ctypes.byref(out_size), - ctypes.byref(handles), - ctypes.byref(out_name_size), - ctypes.byref(names))) - if out_name_size.value == 0: - return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] - else: - assert out_name_size.value == out_size.value - return dict( - (py_str(names[i]), NDArray(NDArrayHandle(handles[i]))) for i in range(out_size.value)) - - -def save(fname, data): - """Saves a list of arrays or a dict of str->array to file. - - Examples of filenames: - - - ``/path/to/file`` - - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports) - - ``hdfs://path/to/file`` (if compiled with HDFS supports) - - Parameters - ---------- - fname : str - The filename. - data : list of ``NDArray` or dict of str to ``NDArray`` - The data to save. 
- - Examples - -------- - >>> x = mx.nd.zeros((2,3)) - >>> y = mx.nd.ones((1,4)) - >>> mx.nd.save('my_list', [x,y]) - >>> mx.nd.save('my_dict', {'x':x, 'y':y}) - >>> mx.nd.load('my_list') - [, ] - >>> mx.nd.load('my_dict') - {'y': , 'x': } - """ - handles = [] - if isinstance(data, dict): - keys = [] - for key, val in data.items(): - if not isinstance(key, string_types): - raise TypeError('save only accept dict str->NDArray or list of NDArray') - if not isinstance(val, NDArray): - raise TypeError('save only accept dict str->NDArray or list of NDArray') - keys.append(c_str(key)) - handles.append(val.handle) - keys = c_array(ctypes.c_char_p, keys) - else: - for val in data: - if not isinstance(val, NDArray): - raise TypeError('save only accept dict str->NDArray or list of NDArray') - handles.append(val.handle) - keys = None - check_call(_LIB.MXNDArraySave(c_str(fname), - mx_uint(len(handles)), - c_array(NDArrayHandle, handles), - keys)) - def concatenate(arrays, axis=0, always_copy=True): """DEPRECATED, use ``concat`` instead @@ -2408,9 +2325,8 @@ def %s(%s): # pylint: enable=too-many-locals, invalid-name -def _init_ndarray_module(ndarray_class, root_namespace): +def _init_ndarray_module(root_namespace): """List and add all the ndarray functions to current module.""" - _set_ndarray_class(ndarray_class) plist = ctypes.POINTER(ctypes.c_char_p)() size = ctypes.c_uint() @@ -2436,7 +2352,8 @@ def _init_ndarray_module(ndarray_class, root_namespace): else: setattr(module_obj, function.__name__, function) -_init_ndarray_module(NDArray, "mxnet") +# register backend operators in mx.nd +_init_ndarray_module("mxnet") # from .base import add_fileline_to_docstring # add_fileline_to_docstring(__name__) diff --git a/python/mxnet/ndarray_utils.py b/python/mxnet/ndarray_utils.py new file mode 100644 index 000000000000..5f8fa6c7bfb7 --- /dev/null +++ b/python/mxnet/ndarray_utils.py @@ -0,0 +1,198 @@ +# coding: utf-8 +"""Utility functions for NDArray and SparseNDArray.""" +import ctypes +import sys as _sys + +from mxnet import Context +from mxnet.base import mx_real_t, _LIB, check_call, py_str, c_str, string_types, mx_uint,\ + NDArrayHandle, c_array +from mxnet.ndarray import NDArray +from mxnet.sparse_ndarray import _STORAGE_AUX_TYPES, _new_alloc_handle, _ndarray_cls +from . import _ndarray_internal as _internal + + +def _zeros_ndarray(shape, ctx=None, dtype=None, **kwargs): + """Returns a new array filled with all zeros, with the given shape and type. + + Parameters + ---------- + shape : int or tuple of int + The shape of the empty array. + ctx : Context, optional + An optional device context (default is the current default context). + dtype : str or numpy.dtype, optional + An optional value type (default is `float32`). + out : NDArray, optional + The output NDArray (default is `None`). 
+ + Returns + ------- + NDArray + A created array + + Examples + -------- + >>> mx.nd.zeros(1).asnumpy() + array([ 0.], dtype=float32) + >>> mx.nd.zeros((1,2), mx.gpu(0)) + + >>> mx.nd.zeros((1,2), mx.gpu(0), 'float16').asnumpy() + array([[ 0., 0.]], dtype=float16) + """ + # pylint: disable= unused-argument + if ctx is None: + ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype + # pylint: disable= no-member, protected-access + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) + # pylint: enable= no-member, protected-access + + +def _zeros_sparse_ndarray(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs): + """Return a new array of given shape and type, filled with zeros. + + Parameters + ---------- + shape : int or tuple of int + The shape of the empty array + storage_type: string + The storage type of the empty array, such as 'row_sparse', 'csr', etc + ctx : Context, optional + An optional device context (default is the current default context) + dtype : str or numpy.dtype, optional + An optional value type (default is `float32`) + aux_types: list of numpy.dtype, optional + An optional type for the aux data for SparseNDArray (default values depends + on the storage type) + + Returns + ------- + SparseNDArray + A created array + Examples + -------- + >>> mx.sparse_nd.zeros('csr', (1,2), mx.gpu(0)) + + >>> mx.sparse_nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy() + array([[ 0., 0.]], dtype=float16) + """ + if storage_type == 'default': + return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs) + if ctx is None: + ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype + if aux_types is None: + if storage_type == 'row_sparse' or storage_type == 'csr': + aux_types = _STORAGE_AUX_TYPES[storage_type] + else: + raise Exception("unknown storage type") + assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type])) + out = _ndarray_cls(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types)) + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs) + + +def zeros(shape, ctx=None, dtype=None, storage_type=None, aux_types=None, **kwargs): + if storage_type is None: + return _zeros_ndarray(shape, ctx, dtype, **kwargs) + else: + return _zeros_sparse_ndarray(storage_type, shape, ctx, dtype, aux_types, **kwargs) + + +def load(fname): + """Loads an array from file. + + See more details in ``save``. + + Parameters + ---------- + fname : str + The filename. + + Returns + ------- + list of NDArray or dict of str to NDArray + Loaded data. + """ + if not isinstance(fname, string_types): + raise TypeError('fname required to be a string') + out_size = mx_uint() + out_name_size = mx_uint() + handles = ctypes.POINTER(NDArrayHandle)() + names = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXNDArrayLoad(c_str(fname), + ctypes.byref(out_size), + ctypes.byref(handles), + ctypes.byref(out_name_size), + ctypes.byref(names))) + if out_name_size.value == 0: + return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)] + else: + assert out_name_size.value == out_size.value + return dict( + (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i]))) + for i in range(out_size.value)) + + +def save(fname, data): + """Saves a list of arrays or a dict of str->array to file. 
+ + Examples of filenames: + + - ``/path/to/file`` + - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports) + - ``hdfs://path/to/file`` (if compiled with HDFS supports) + + Parameters + ---------- + fname : str + The filename. + data : list of ``NDArray` or dict of str to ``NDArray`` + The data to save. + + Examples + -------- + >>> x = mx.nd.zeros((2,3)) + >>> y = mx.nd.ones((1,4)) + >>> mx.nd.save('my_list', [x,y]) + >>> mx.nd.save('my_dict', {'x':x, 'y':y}) + >>> mx.nd.load('my_list') + [, ] + >>> mx.nd.load('my_dict') + {'y': , 'x': } + """ + handles = [] + if isinstance(data, dict): + keys = [] + for key, val in data.items(): + if not isinstance(key, string_types): + raise TypeError('save only accept dict str->NDArray or list of NDArray') + if not isinstance(val, NDArray): + raise TypeError('save only accept dict str->NDArray or list of NDArray') + keys.append(c_str(key)) + handles.append(val.handle) + keys = c_array(ctypes.c_char_p, keys) + else: + for val in data: + if not isinstance(val, NDArray): + raise TypeError('save only accept dict str->NDArray or list of NDArray') + handles.append(val.handle) + keys = None + check_call(_LIB.MXNDArraySave(c_str(fname), + mx_uint(len(handles)), + c_array(NDArrayHandle, handles), + keys)) + + +def _init_ndarray_module_frontend(function, root_namespace, module_name): + """Register front end functions defined in this file to mxnet.ndarray module. + The functions registered were originally defined in mxnet.ndarray. They were + moved here because they need to know SparseNDArray class, while it's not allowed + in ndarray.py since that would result in circular import.""" + module_obj = _sys.modules["%s.%s" % (root_namespace, module_name)] + setattr(module_obj, function.__name__, function) + + +# register the following front end functions in mx.nd +_init_ndarray_module_frontend(zeros, "mxnet", "ndarray") +_init_ndarray_module_frontend(load, "mxnet", "ndarray") +_init_ndarray_module_frontend(save, "mxnet", "ndarray") diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 1f7b1d3aed1b..10f9f06c11b3 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -2,8 +2,10 @@ import math import pickle import logging -from .ndarray import NDArray, zeros, clip, sqrt, sign +import mxnet as mx +from .ndarray import NDArray, clip, sqrt, sign from .ndarray import sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update +from .ndarray_utils import zeros from .random import normal @@ -332,7 +334,8 @@ def create_state(self, index, weight): if self.momentum == 0.0: return None else: - return zeros(weight.shape, weight.context, dtype=weight.dtype) + return mx.nd.zeros(shape=weight.shape, ctx=weight.context, + dtype=weight.dtype, storage_type=weight.storage_type) def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) @@ -510,8 +513,8 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, self.epsilon = epsilon def create_state(self, index, weight): - return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean - zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance + return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype), # mean + mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) @@ -616,11 +619,11 @@ def __init__(self, learning_rate=0.001, gamma1=0.9, gamma2=0.9, def create_state(self, index, weight): if 
self.centered: return ( - zeros(weight.shape, weight.context), # n - zeros(weight.shape, weight.context), # g - zeros(weight.shape, weight.context)) # delta + mx.nd.zeros(weight.shape, weight.context), # n + mx.nd.zeros(weight.shape, weight.context), # g + mx.nd.zeros(weight.shape, weight.context)) # delta else: - return (zeros(weight.shape, weight.context), ) # n + return (mx.nd.zeros(weight.shape, weight.context), ) # n def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py new file mode 100644 index 000000000000..5faddd979078 --- /dev/null +++ b/python/mxnet/sparse_ndarray.py @@ -0,0 +1,576 @@ +# coding: utf-8 +"""SparseNDArray API of mxnet.""" +from __future__ import absolute_import +from __future__ import division +try: + from __builtin__ import slice as py_slice +except ImportError: + from builtins import slice as py_slice + +import ctypes +import warnings + +import os as _os +import sys as _sys + +# import operator +import numpy as np +import mxnet as mx +from .base import _LIB, numeric_types +from .base import c_array, mx_real_t +from .base import mx_uint, NDArrayHandle, check_call +from .context import Context +from . import _ndarray_internal as _internal +from . import ndarray +from .ndarray import _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .ndarray import _STORAGE_TYPE_STR_TO_ID +from .ndarray import NDArray, _storage_type + +# Use different verison of SymbolBase +# When possible, use cython to speedup part of computation. +# pylint: disable=unused-import +try: + if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: + from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class + elif _sys.version_info >= (3, 0): + from ._cy3.ndarray import NDArrayBase, _set_ndarray_class + else: + from ._cy2.ndarray import NDArrayBase, _set_ndarray_class +except ImportError: + if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: + raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") + from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class + +# pylint: enable=unused-import +_STORAGE_AUX_TYPES = { + 'row_sparse': [np.int64], + 'csr': [np.int64, np.int64] +} + + +def _new_alloc_handle(storage_type, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None): + """Return a new handle with specified storage type, shape, dtype and context. + + Empty handle is only used to hold results + + Returns + ------- + handle + A new empty ndarray handle + """ + hdl = NDArrayHandle() + aux_type_ids = [int(_DTYPE_NP_TO_MX[np.dtype(aux_t).type]) for aux_t in aux_types] + aux_shapes = [(0,) for aux_t in aux_types] if aux_shapes is None else aux_shapes + aux_shape_lens = [len(aux_shape) for aux_shape in aux_shapes] + aux_shapes = sum(aux_shapes, ()) + num_aux = mx_uint(len(aux_types)) + check_call(_LIB.MXNDArrayCreateSparseEx( + ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[storage_type])), + c_array(mx_uint, shape), + mx_uint(len(shape)), + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + ctypes.c_int(int(delay_alloc)), + ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])), + num_aux, + c_array(ctypes.c_int, aux_type_ids), + c_array(mx_uint, aux_shape_lens), + c_array(mx_uint, aux_shapes), + ctypes.byref(hdl))) + return hdl + + +class SparseNDArray(NDArray): + """An array object representing a multidimensional, homogeneous array of + fixed-size items, stored in sparse format. See CSRNDArray and RowSparseNDArray + for more details. 
+ """ + def __iadd__(self, other): + raise NotImplementedError("SparseND doesn't support __iadd__") + + def __isub__(self, other): + raise NotImplementedError("SparseND doesn't support __isub__") + + def __imul__(self, other): + raise NotImplementedError("SparseND doesn't support __imul__") + + def __idiv__(self, other): + raise NotImplementedError("SparseND doesn't support __idiv__") + + def __itruediv__(self, other): + raise NotImplementedError("SparseND doesn't support __itruediv__") + + def __setitem__(self, key, value): + """x.__setitem__(i, y) <=> x[i]=y + + Set self[key] to value. Only slice [:] is supported. + + Parameters + ---------- + key : slice + The indexing key. + value : NDArray or numpy.ndarray + The value to set. + + Examples + -------- + >>> src = mx.sparse_nd.row_sparse(data, indices, (3,3)) + >>> src.asnumpy() + array([[ 1., 0., 2.], + [ 0., 0., 0.], + [ 4., 5., 6.]], dtype=float32) + >>> # assign SparseNDArray with same storage type + >>> x = mx.sparse_nd.zeros('row_sparse', (3,3)) + >>> x[:] = src + >>> x.asnumpy() + array([[ 1., 0., 2.], + [ 0., 0., 0.], + [ 4., 5., 6.]], dtype=float32) + >>> # assign NDArray to SparseNDArray + >>> x[:] = mx.nd.ones((3,3)) + >>> x.asnumpy() + array([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]], dtype=float32) + """ + if not self.writable: + raise ValueError('Failed to assign to a readonly NDArray') + if isinstance(key, py_slice): + if key.step is not None or key.start is not None or key.stop is not None: + raise ValueError('Assignment with slicing not supported in SparseNDArray.') + if isinstance(value, NDArray): + # avoid copying to itself + if value.handle is not self.handle: + value.copyto(self) + elif isinstance(value, numeric_types): + raise Exception("Assigning numeric types to SparseNDArray not supported yet.") + elif isinstance(value, (np.ndarray, np.generic)): + # TODO(haibin) Implement _sync_copyfrom for sparse ndarray to avoid an extra copy + warnings.warn('Assigning non-NDArray object to SparseNDArray is not efficient', + RuntimeWarning) + tmp = ndarray.array(value) + tmp.copyto(self) + else: + raise TypeError('type %s not supported' % str(type(value))) + else: + assert(isinstance(key, (int, tuple))) + raise Exception('SparseNDArray only supports [:] for assignment') + + def __getitem__(self, key): + """x.__getitem__(i) <=> x[i] + + Returns a sliced view of this array. + + Parameters + ---------- + key : int or slice + Indexing key. + + Examples + -------- + >>> x[:] = mx.nd.arange(0,6).reshape((2,3)) + >>> x.asnumpy() + array([[ 0., 1., 2.], + [ 3., 4., 5.]], dtype=float32) + >>> x[1:2].asnumpy() + array([[ 3., 4., 5.]], dtype=float32) + """ + stype = self.storage_type + if stype != 'csr': + raise Exception("__getitem__ for " + str(stype) + " not implemented yet") + if isinstance(key, int): + raise Exception("Not implemented yet") + if isinstance(key, py_slice): + if key.step is not None: + raise ValueError('NDArray only supports continuous slicing on axis 0') + if key.start is not None or key.stop is not None: + return self._slice(key.start, key.stop) + else: + return self + if isinstance(key, tuple): + raise ValueError('Multi-dimension indexing is not supported') + + def _sync_copyfrom(self, source_array): + raise Exception('Not implemented for SparseND yet!') + + def _slice(self, start, stop): + """Returns a read-only SparseNDArray slice that shares memory with current one. + To create a writable slice, please use ``mx.nd.slice`` instead. Currently only + `csr` storage type is supported. 
+ + Parameters + ---------- + start : int + Starting index of slice. + stop : int + Finishing index of slice. + + Example + ---------- + >>> indptr = np.array([0, 2, 3, 6]) + >>> indices = np.array([0, 2, 2, 0, 1, 2]) + >>> data = np.array([1, 2, 3, 4, 5, 6]) + >>> a = mx.sparse_nd.csr(data, indptr, indices, (3, 3)) + >>> a.asnumpy() + array([[1, 0, 2], + [0, 0, 3], + [4, 5, 6]]) + + >>> a[1:2].asnumpy() + array([[0, 0, 3]]) + + """ + stype = self.storage_type + assert(stype == 'csr'), "_slice for " + str(stype) + " not implemented yet" + warnings.warn('slicing SparseNDArray is not efficient', RuntimeWarning) + handle = NDArrayHandle() + start = mx_uint(start) if start else mx_uint(0) + stop = mx_uint(stop) if stop else mx_uint(self.shape[0]) + check_call(_LIB.MXNDArraySlice( + self.handle, start, stop, ctypes.byref(handle))) + ret = _ndarray_cls(handle=handle, writable=False) + return ret + + def _at(self, idx): + raise Exception('at operator for SparseND is not supported.') + + def reshape(self, shape): + raise Exception('Not implemented for SparseND yet!') + + def broadcast_to(self, shape): + raise Exception('Not implemented for SparseND yet!') + + def _aux_type(self, i): + """Data-type of the array’s ith aux data. + + Returns + ------- + numpy.dtype + This SparseNDArray's aux data type. + """ + aux_type = ctypes.c_int() + check_call(_LIB.MXNDArrayGetAuxType(self.handle, i, ctypes.byref(aux_type))) + return _DTYPE_MX_TO_NP[aux_type.value] + + @property + def values(self): + """The values array of the SparseNDArray. This is a read-only view of the values array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's values array. + """ + return self._data() + + + @property + def _num_aux(self): + ''' The number of aux data used to help store the sparse ndarray. + ''' + return len(_STORAGE_AUX_TYPES[self.storage_type]) + + @property + # pylint: disable= invalid-name, undefined-variable + def T(self): + raise Exception('Transpose is not supported for SparseNDArray.') + + @property + def aux_types(self): + """The data types of the aux data for the SparseNDArray. + """ + aux_types = [] + num_aux = self._num_aux + for i in range(num_aux): + aux_types.append(self._aux_type(i)) + return aux_types + + def asnumpy(self): + """Return a dense ``numpy.ndarray`` object with value copied from this array + + """ + return self.todense().asnumpy() + + def astype(self, dtype): + """Returns a copy of the array after casting to a specified type. + Parameters + ---------- + dtype : numpy.dtype or str + The type of the returned array. + Examples + -------- + >>> x = mx.sparse_nd.zeros('row_sparse', (2,3), dtype='float32') + >>> y = x.astype('int32') + >>> y.dtype + + """ + res = mx.nd.zeros(shape=self.shape, ctx=self.context, + dtype=dtype, storage_type=self.storage_type) + self.copyto(res) + return res + + def copyto(self, other): + """Copies the value of this array to another array. + + If ``other`` is a ``NDArray`` object, then ``other.shape`` and + ``self.shape`` should be the same. This function copies the value from + ``self`` to ``other``. + + If ``other`` is a context, a new ``NDArray`` will be first created on + the target context, and the value of ``self`` is copied. + + Parameters + ---------- + other : NDArray or Context + The destination array or context. + + Returns + ------- + NDArray + The copied array. If ``other`` is an ``NDArray``, then the return value + and ``other`` will point to the same ``NDArray``. 
+ """ + if isinstance(other, NDArray): + if other.handle is self.handle: + warnings.warn('You are attempting to copy an array to itself', RuntimeWarning) + return + return _internal._copyto(self, out=other) + elif isinstance(other, Context): + hret = _ndarray_cls(_new_alloc_handle(self.storage_type, self.shape, other, + True, self.dtype, self.aux_types)) + return _internal._copyto(self, out=hret) + else: + raise TypeError('copyto does not support type ' + str(type(other))) + + def todense(self): + return todense(self) + + def _aux_data(self, i, writable=False): + """ Get an NDArray referencing the ith aux data array associated with the SparseNDArray. + """ + self.wait_to_read() + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayGetAuxNDArray(self.handle, i, ctypes.byref(hdl))) + return NDArray(hdl, writable) + + def _data(self, writable=False): + """ Get an NDArray referencing the value array associated with the SparseNDArray. + """ + self.wait_to_read() + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayGetDataNDArray(self.handle, ctypes.byref(hdl))) + return NDArray(hdl, writable) + +# pylint: disable=abstract-method +class CSRNDArray(SparseNDArray): + """A CSRNDArray represents a NDArray as three separate arrays: `values`, + `indptr` and `indices`. It uses the standard CSR representation where the column indices for + row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored + in values[indptr[i]:indptr[i+1]]. + + """ + def __reduce__(self): + return CSRNDArray, (None,), super(CSRNDArray, self).__getstate__() + + @property + def indices(self): + """The indices array of the SparseNDArray. This is a read-only view of the indices array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indices array. + """ + return self._aux_data(1) + + @property + def indptr(self): + """The indptr array of the SparseNDArray with `csr` storage type. + This is a read-only view of the indptr array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indptr array. + """ + return self._aux_data(0) + +# pylint: disable=abstract-method +class RowSparseNDArray(SparseNDArray): + """A RowSparseNDArray is typically used to represent a subset of a larger + NDArray with `default` of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. The values + in indices are the indices in the first dimension of the slices that have been extracted from + the larger NDArray. The indices are expected to be sorted in ascending order. + + The corresponding NDArray ``dense`` with `default` storage represented by a ``rsp`` + RowSparseNDArray + + ``dense[rsp.indices[i], :, :, :, ...] = rsp.values[i, :, :, :, ...]`` + + RowSparseNDArray is used principally in the definition of gradients for operations + that have sparse gradients (e.g. SparseEmbedding). + """ + def __reduce__(self): + return RowSparseNDArray, (None,), super(RowSparseNDArray, self).__getstate__() + + @property + def indices(self): + """The indices array of the SparseNDArray. This is a read-only view of the indices array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indices array. 
+ """ + return self._aux_data(0) + + +def _prepare_src_array(src, dtype, default_dtype): + if isinstance(src, NDArray): + dtype = src.dtype if dtype is None else dtype + else: + dtype = default_dtype if dtype is None else dtype + if not isinstance(src, np.ndarray): + try: + src = np.array(src, dtype=dtype) + except: + raise TypeError('values must be array like object') + return src, dtype + + +def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None): + """Creates a 2D array with compressed sparse row format. + + Parameters + ---------- + values: array_like + An object exposing the array interface, with shape [nnz], where D0 is the number of + non-zero entries. + indptr: array_like + An object exposing the array interface, with shape [D0 + 1]. The first element in indptr + should always be zero. + indices: array_like + An object exposing the array interface, with shape [nnz]. + ctx : Context, optional + Device context (default is the current default context). + dtype : str or numpy.dtype, optional + The data type of the output array. The default dtype is ``values.dtype`` + if `values` is an `NDArray`, `float32` otherwise. + indptr_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indptr.dtype`` + if `indptr` is an `NDArray`, `int32` otherwise. + indices_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indices.dtype`` + if `indicies` is an `NDArray`, `int32` otherwise. + + Returns + ------- + CSRNDArray + A `CSRNDArray` with the `csr` storage representation. + """ + storage_type = 'csr' + # context + if ctx is None: + ctx = Context.default_ctx + # prepare src array and types + values, dtype = _prepare_src_array(values, dtype, mx_real_t) + indptr, indptr_type = _prepare_src_array(indptr, indptr_type, + _STORAGE_AUX_TYPES[storage_type][0]) + indices, indices_type = _prepare_src_array(indices, indices_type, + _STORAGE_AUX_TYPES[storage_type][1]) + # verify types + assert('int64' in str(indptr_type)), "expected int64 for indptr" + assert('int64' in str(indices_type)), "expected int64 for indices" + # verify shapes + aux_shapes = [indptr.shape, indices.shape] + assert(values.ndim == 1) + assert(indptr.ndim == 1) + assert(indices.ndim == 1) + assert(len(shape) == 2) + result = CSRNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype, + [indptr_type, indices_type], aux_shapes)) + # assign indptr, indices and values + values_ref = result._data(True) + indptr_ref = result._aux_data(0, True) + indices_ref = result._aux_data(1, True) + values_ref[:] = values + indptr_ref[:] = indptr + indices_ref[:] = indices + return result + + +def row_sparse(values, indices, shape, ctx=None, dtype=None, indices_type=None): + """Creates a row sparse array with a set of tensor slices at given indices. + + Parameters + ---------- + values: array_like + An object exposing the array interface, with shape [D0, D1, .. Dn], where D0 is + the number of rows with non-zeros entries. + indices: array_like + An object exposing the array interface, with shape [D0]. + ctx : Context, optional + Device context (default is the current default context). + dtype : str or numpy.dtype, optional + The data type of the output array. The default dtype is ``values.dtype`` + if `values` is an `NDArray`, `float32` otherwise. + indices_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indices.dtype`` + if `indicies` is an `NDArray`, `int32` otherwise. 
+ + Returns + ------- + RowSparseNDArray + An `RowSparseNDArray` with the `row_sparse` storage representation. + """ + storage_type = 'row_sparse' + # context + if ctx is None: + ctx = Context.default_ctx + # prepare src array and types + values, dtype = _prepare_src_array(values, dtype, mx_real_t) + indices, indices_type = _prepare_src_array(indices, indices_type, + _STORAGE_AUX_TYPES[storage_type][0]) + # verify types + assert('int64' in str(indices_type)), "expected int64 for indices" + # verify shapes + assert(values.ndim == len(shape)) + assert(indices.ndim == 1) + result = RowSparseNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype, + [indices_type], [indices.shape])) + # assign indices and values + values_ref = result._data(True) + indices_ref = result._aux_data(0, True) + values_ref[:] = values + indices_ref[:] = indices + return result + + +def todense(source): + """ Return a dense array representation of this SparseNDArray. + + Returns + ------- + NDArray + The dense array with default storage + """ + return ndarray.cast_storage(source, storage_type='default') + + +def _ndarray_cls(handle, writable=True): + stype = _storage_type(handle) + if stype == 'default': + return NDArray(handle, writable=writable) + elif stype == 'csr': + return CSRNDArray(handle, writable=writable) + elif stype == 'row_sparse': + return RowSparseNDArray(handle, writable=writable) + else: + raise Exception("unknown storage type") + + +_set_ndarray_class(_ndarray_cls) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 14203e59862d..e752eb541648 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -19,6 +19,8 @@ from .context import Context, cpu from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP from .name import NameManager # pylint: disable=unused-import +from .ndarray import _STORAGE_TYPE_ID_TO_STR, _STORAGE_TYPE_STR_TO_ID +from .sparse_ndarray import _ndarray_cls from .executor import Executor from . import _symbol_internal as _internal from .attribute import AttrScope @@ -721,6 +723,89 @@ def list_auxiliary_states(self): self.handle, ctypes.byref(size), ctypes.byref(sarr))) return [py_str(sarr[i]) for i in range(size.value)] + def infer_storage_type(self, *args, **kwargs): + """Infer the storage type of outputs and arguments of given known types of arguments. + + User can either pass in the known types in positional way or keyword argument way. + Tuple of Nones is returned if there is not enough information passed in. + An error will be raised if there is inconsistency found in the known types passed in. + + Parameters + ---------- + *args : + Provide type of arguments in a positional way. + Unknown type can be marked as None + + **kwargs : + Provide keyword arguments of known types. + + Returns + ------- + arg_storage_types : list of numpy.dtype or None + List of types of arguments. + The order is in the same order as list_arguments() + out_storage_types : list of numpy.dtype or None + List of types of outputs. + The order is in the same order as list_outputs() + aux_storage_types : list of numpy.dtype or None + List of types of outputs. 
+ The order is in the same order as list_auxiliary_states() + """ + # pylint: disable=too-many-locals + if len(args) != 0 and len(kwargs) != 0: + raise ValueError('Can only specify known argument \ + types either by positional or kwargs way.') + sdata = [] + if len(args) != 0: + keys = None + for s in args: + if s is not None: + if s not in _STORAGE_TYPE_STR_TO_ID or not isinstance(s, basestring): + raise TypeError('Argument need to be one of '+str(_STORAGE_TYPE_STR_TO_ID)) + sdata.append(_STORAGE_TYPE_STR_TO_ID[s]) + else: + sdata.append(_STORAGE_TYPE_STR_TO_ID['undefined']) + else: + keys = [] + for k, v in kwargs.items(): + if v in _STORAGE_TYPE_STR_TO_ID: + keys.append(c_str(k)) + sdata.append(_STORAGE_TYPE_STR_TO_ID[v]) + arg_storage_type_size = mx_uint() + arg_storage_type_data = ctypes.POINTER(ctypes.c_int)() + out_storage_type_size = mx_uint() + out_storage_type_data = ctypes.POINTER(ctypes.c_int)() + aux_storage_type_size = mx_uint() + aux_storage_type_data = ctypes.POINTER(ctypes.c_int)() + complete = ctypes.c_int() + check_call(_LIB.MXSymbolInferStorageType( + self.handle, + mx_uint(len(sdata)), + c_array(ctypes.c_char_p, keys), + c_array(ctypes.c_int, sdata), + ctypes.byref(arg_storage_type_size), + ctypes.byref(arg_storage_type_data), + ctypes.byref(out_storage_type_size), + ctypes.byref(out_storage_type_data), + ctypes.byref(aux_storage_type_size), + ctypes.byref(aux_storage_type_data), + ctypes.byref(complete))) + if complete.value != 0: + arg_storage_types = [ + _STORAGE_TYPE_ID_TO_STR[arg_storage_type_data[i]] \ + for i in range(arg_storage_type_size.value)] + out_storage_types = [ + _STORAGE_TYPE_ID_TO_STR[out_storage_type_data[i]] \ + for i in range(out_storage_type_size.value)] + aux_storage_types = [ + _STORAGE_TYPE_ID_TO_STR[aux_storage_type_data[i]] \ + for i in range(aux_storage_type_size.value)] + return (arg_storage_types, out_storage_types, aux_storage_types) + else: + return (None, None, None) + # pylint: enable=too-many-locals + + def infer_type(self, *args, **kwargs): """Infers the type of all arguments and all outputs, given the known types for some arguments. @@ -1160,8 +1245,9 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing): raise TypeError('Only accept list of NDArrays or dict of str to NDArray') return c_array(NDArrayHandle, arg_handles), arg_arrays - def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, - shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs): + def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=None, + group2ctx=None, shared_arg_names=None, shared_exec=None, + shared_buffer=None, **kwargs): """Bind current symbol to get an executor, allocate all the arguments needed. Allows specifying data types. @@ -1203,6 +1289,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, type_dict : Dict of str->numpy.dtype Input type dictionary, name->dtype + storage_type_dict : Dict of str->str + Input storage type dictionary, name->storage_type + group2ctx : Dict of string to mx.Context The dict mapping the `ctx_group` attribute to the context assignment. @@ -1217,7 +1306,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer : Dict of string to `NDArray` The dict mapping argument names to the `NDArray` that can be reused for initializing the current executor. This buffer will be checked for reuse if one argument name - of the current executor is not found in `shared_arg_names`. 
+ of the current executor is not found in `shared_arg_names`. The `NDArray`s are + expected have default storage type. kwargs : Dict of str->shape Input shape dictionary, name->shape @@ -1227,6 +1317,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, executor : mxnet.Executor The generated executor """ + # data types num_provided_arg_types = 0 provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types @@ -1242,6 +1333,22 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names) provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data) + # storage types + num_provided_arg_stypes = 0 + # provided storage type argument names + provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)() + provided_arg_stype_data = ctypes.POINTER(mx_uint)() # provided storage types + if storage_type_dict is not None: + provided_arg_stype_names = [] + provided_arg_stype_data = [] + for k, v in storage_type_dict.items(): + if v in _STORAGE_TYPE_STR_TO_ID: + provided_arg_stype_names.append(c_str(k)) + provided_arg_stype_data.append(ctypes.c_int(_STORAGE_TYPE_STR_TO_ID[v])) + num_provided_arg_stypes = mx_uint(len(provided_arg_stype_names)) + provided_arg_stype_names = c_array(ctypes.c_char_p, provided_arg_stype_names) + provided_arg_stype_data = c_array(ctypes.c_int, provided_arg_stype_data) + provided_arg_shape_data = [] # shape data # argument shape index in sdata, # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg @@ -1315,6 +1422,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer_names = [] shared_buffer_handles = [] for k, v in shared_buffer.items(): + assert(v.storage_type == 'default'), \ + "shared_buffer is expected to only contain NDArrays with default storage" shared_buffer_names.append(c_str(k)) shared_buffer_handles.append(v.handle) shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names) @@ -1354,6 +1463,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, num_provided_arg_types, provided_arg_type_names, provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, mx_uint(len(shared_arg_name_list)), c_array(ctypes.c_char_p, shared_arg_name_list), ctypes.byref(shared_buffer_len), @@ -1383,11 +1495,12 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer[k] = v # create in_args, arg_grads, and aux_states for the current executor - arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)] - grad_arrays = [NDArray(NDArrayHandle(arg_grad_handles[i])) + arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) \ + for i in range(num_in_args.value)] + grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i])) if arg_grad_handles[i] is not None else None for i in range(num_in_args.value)] - aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i])) + aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i])) for i in range(num_aux_states.value)] executor = Executor(exe_handle, self, ctx, grad_req, group2ctx) @@ -1638,7 +1751,8 @@ def reshape(self, shape): """ return reshape(self, shape=shape) -def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, init=None, **kwargs): +def var(name, attr=None, shape=None, 
lr_mult=None, wd_mult=None, dtype=None, + init=None, storage_type=None, **kwargs): """Creates a symbolic variable with specified name. Example usage: @@ -1692,6 +1806,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, ini if not isinstance(init, string_types): init = init.dumps() attr['__init__'] = init + if storage_type is not None: + attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[storage_type]) for k, v in kwargs.items(): if k.startswith('__') and k.endswith('__'): attr[k] = str(v) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 3ab44d0917a1..d860b531e520 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -10,17 +10,19 @@ import os import errno import logging +import scipy.sparse as sp import numpy as np import numpy.testing as npt -import mxnet as mx -from .context import Context -from .ndarray import array -from .symbol import Symbol +import numpy.random as rnd try: import requests except ImportError: # in rare cases requests may be not installed pass +import mxnet as mx +from .context import Context +from .ndarray import array, _STORAGE_TYPE_STR_TO_ID +from .symbol import Symbol _rng = np.random.RandomState(1234) @@ -66,6 +68,53 @@ def random_arrays(*shapes): return arrays +def random_sample(population, k): + """Return a k length list of the elements chosen from the population sequence.""" + assert 0 <= k <= len(population) + population_copy = population[:] + np.random.shuffle(population_copy) + return population_copy[0:k] + + +def rand_sparse_ndarray(shape, storage_type, density=None): + """Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) """ + density = rnd.rand() if density is None else density + if storage_type == 'row_sparse': + # TODO(haibin) support high dim sparse ndarray + assert(len(shape) < 3) + prod = np.prod(shape) + num_cols = int(prod / shape[0]) + # sample index + idx_sample = rnd.rand(shape[0]) + indices = np.argwhere(idx_sample < density).flatten() + if indices.shape[0] == 0: + result = mx.nd.zeros(shape, storage_type='row_sparse') + return result, (np.array([], dtype='int64'), np.array([], dtype='int64')) + # generate random values + val = rnd.rand(indices.shape[0], num_cols) + arr = mx.sparse_nd.row_sparse(val, indices, shape, indices_type=np.int64) + return arr, (val, indices) + elif storage_type == 'csr': + assert(len(shape) == 2) + csr = sp.rand(shape[0], shape[1], density=density, format='csr') + result = mx.sparse_nd.csr(csr.data, csr.indptr, csr.indices, shape) + return result, (csr.indptr, csr.indices, csr.data) + else: + assert(False), "unknown storage type" + + +def rand_ndarray(shape, storage_type, density=None): + if storage_type == 'default': + arr = mx.nd.array(random_arrays(shape)) + else: + arr, _ = rand_sparse_ndarray(shape, storage_type, density=density) + return arr + + +def rand_shape_2d(dim0=10, dim1=10): + return rnd.randint(1, dim0), rnd.randint(1, dim1) + + def np_reduce(dat, axis, keepdims, numpy_reduce_func): """Compatible reduce for old version of NumPy. 
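# --- Illustrative sketch (not part of the patch) ---------------------------
# How the sparse test helpers added to test_utils.py above are expected to be
# used together: rand_sparse_ndarray returns both the SparseNDArray and its
# numpy/scipy source data, so a test can check the MXNet result against a
# dense reference. The density, shape bounds and tolerances below are
# arbitrary examples chosen for illustration.
import numpy as np
import scipy.sparse as sp
import mxnet as mx
from mxnet.test_utils import rand_ndarray, rand_sparse_ndarray, rand_shape_2d

shape = rand_shape_2d(dim0=10, dim1=10)
csr_nd, (indptr, indices, data) = rand_sparse_ndarray(shape, 'csr', density=0.5)
assert csr_nd.storage_type == 'csr'

# Dense reference rebuilt from the same CSR triplet returned by the helper.
dense_ref = sp.csr_matrix((data, indices, indptr), shape=shape).toarray()
np.testing.assert_allclose(csr_nd.asnumpy(), dense_ref, rtol=1e-5, atol=1e-6)

# rand_ndarray dispatches on storage type; 'default' yields a dense NDArray.
dense_nd = rand_ndarray(shape, 'default')
# ---------------------------------------------------------------------------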
@@ -297,7 +346,8 @@ def _parse_location(sym, location, ctx): % (str(set(sym.list_arguments())), str(set(location.keys())))) else: location = {k: v for k, v in zip(sym.list_arguments(), location)} - location = {k: mx.nd.array(v, ctx=ctx) for k, v in location.items()} + location = {k: mx.nd.array(v, ctx=ctx) if isinstance(v, np.ndarray) \ + else v for k, v in location.items()} return location @@ -418,7 +468,8 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2, - atol=None, grad_nodes=None, use_forward_train=True, ctx=None): + atol=None, grad_nodes=None, use_forward_train=True, ctx=None, + grad_stype_dict=None): """Verify an operation by checking backward pass via finite difference method. Based on Theano's `theano.gradient.verify_grad` [1] @@ -435,7 +486,7 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto - if type is dict of str -> numpy.ndarray maps the name of arguments to the corresponding numpy.ndarray. *In either case, value of all the arguments must be provided.* - aux_states : ist or tuple or dict, optional + aux_states : list or tuple or dict, optional The auxiliary states required when generating the executor for the symbol. numeric_eps : float, optional Delta for the finite difference method that approximates the gradient. @@ -447,6 +498,8 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto Whether to use is_train=True when computing the finite-difference. ctx : Context, optional Check the gradient computation on the specified device. + grad_stype_dict : dict of str->str, optional + Storage type dictionary for gradient ndarrays. References --------- ..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py @@ -470,7 +523,7 @@ def random_projection(shape): location_npy = {k:v.asnumpy() for k, v in location.items()} aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx) if aux_states is not None: - aux_states_npy = {k:v.asnumpy() for k, v in aux_states.items()} + aux_states_npy = {k: v.asnumpy() for k, v in aux_states.items()} else: aux_states_npy = None if grad_nodes is None: @@ -497,6 +550,11 @@ def random_projection(shape): + [("__random_proj", _rng.normal(0, 0.01, size=out_shape[0]))]) args_grad = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()} + if grad_stype_dict is not None: + assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict" + for k, v in grad_stype_dict.items(): + if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default': + args_grad[k] = mx.nd.cast_storage(args_grad[k], storage_type=v) executor = out.bind(ctx, grad_req=grad_req, args=location, args_grad=args_grad, aux_states=aux_states) @@ -588,8 +646,8 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, g[:] = 0 executor.forward(is_train=False) - outputs = [x.asnumpy() for x in executor.outputs] + outputs = [x.asnumpy() for x in executor.outputs] for output_name, expect, output in zip(sym.list_outputs(), expected, outputs): assert_almost_equal(expect, output, rtol, atol, ("EXPECTED_%s"%output_name, "FORWARD_%s"%output_name)) @@ -657,14 +715,29 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()} - args_grad_data = {k: 
mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()} + # args_grad_data should be casted to storage type if hinted + # TODO(haibin) this is a temporary solution for testing. remove later + attrs = sym.attr_dict() + args_grad_data = {} + for k, v in args_grad_npy.items(): + attr = attrs.get(k, {}) + grad_stype = attr.get('grad_stype_hint', None) + nd = mx.nd.array(v, ctx=ctx) + if grad_stype is not None: + out = mx.nd.cast_storage(nd, storage_type=grad_stype) + args_grad_data[k] = out + else: + args_grad_data[k] = nd + if isinstance(grad_req, str): grad_req = {k:grad_req for k in sym.list_arguments()} elif isinstance(grad_req, (list, tuple)): grad_req = {k:v for k, v in zip(sym.list_arguments(), grad_req)} - executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states) + executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, + aux_states=aux_states, grad_req=grad_req) executor.forward(is_train=True) + if isinstance(out_grads, (tuple, list)): out_grads = [mx.nd.array(v, ctx=ctx) for v in out_grads] elif isinstance(out_grads, (dict)): diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 9d60c8615027..f2472f93371e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -154,6 +154,39 @@ int MXNDArrayCreateEx(const mx_uint *shape, API_END(); } +int MXNDArrayCreateSparseEx(int storage_type, + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out) { + API_BEGIN(); + std::vector aux_types; + std::vector aux_shapes; + auto shape_start = aux_shape; + for (size_t i = 0; i < num_aux; i++) { + // types + aux_types.push_back(aux_type[i]); + // shapes + aux_shapes.emplace_back(shape_start, shape_start + aux_ndims[i]); + shape_start += aux_ndims[i]; + } + *out = new NDArray( + NDArrayStorageType(storage_type), + TShape(shape, shape + ndim), + Context::Create(static_cast(dev_type), dev_id), + delay_alloc != 0, + dtype, aux_types, aux_shapes); + API_END(); +} + + int MXNDArrayLoadFromRawBytes(const void *buf, size_t size, NDArrayHandle *out) { @@ -333,6 +366,18 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle, API_END_HANDLE_ERROR(delete ptr); } +int MXNDArrayGetStorageType(NDArrayHandle handle, + int *out_storage_type) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + if (!arr->is_none()) { + *out_storage_type = arr->storage_type(); + } else { + *out_storage_type = kUndefinedStorage; + } + API_END(); +} + int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata) { @@ -382,6 +427,32 @@ int MXNDArrayGetDType(NDArrayHandle handle, API_END(); } +int MXNDArrayGetAuxType(NDArrayHandle handle, + mx_uint i, + int *out_type) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out_type = arr->aux_type(i); + API_END(); +} + +int MXNDArrayGetAuxNDArray(NDArrayHandle handle, + mx_uint i, + NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = new NDArray(arr->aux_ndarray(i)); + API_END(); +} + +int MXNDArrayGetDataNDArray(NDArrayHandle handle, + NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = new NDArray(arr->data_ndarray()); + API_END(); +} + int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_type, int *out_dev_id) { @@ -625,6 +696,21 @@ int MXKVStoreInit(KVStoreHandle handle, API_END(); } +int MXKVStoreInitEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* 
vals) { + API_BEGIN(); + std::vector v_keys(num); + std::vector v_vals(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_vals[i] = *static_cast(vals[i]); + } + static_cast(handle)->Init(v_keys, v_vals); + API_END(); +} + int MXKVStorePush(KVStoreHandle handle, mx_uint num, const int* keys, @@ -641,6 +727,22 @@ int MXKVStorePush(KVStoreHandle handle, API_END(); } +int MXKVStorePushEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority) { + API_BEGIN(); + std::vector v_keys(num); + std::vector v_vals(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_vals[i] = *static_cast(vals[i]); + } + static_cast(handle)->Push(v_keys, v_vals, priority); + API_END(); +} + int MXKVStorePull(KVStoreHandle handle, mx_uint num, const int* keys, @@ -657,6 +759,22 @@ int MXKVStorePull(KVStoreHandle handle, API_END(); } +int MXKVStorePullEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority) { + API_BEGIN(); + std::vector v_keys(num); + std::vector v_vals(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_vals[i] = static_cast(vals[i]); + } + static_cast(handle)->Pull(v_keys, v_vals, priority); + API_END(); +} + int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater, void* updater_handle) { diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index d8857f80635d..f2cad238a71b 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -58,6 +58,8 @@ struct MXAPIThreadLocalEntry { std::vector arg_shapes, out_shapes, aux_shapes; /*! \brief result holder for returning type flags */ std::vector arg_types, out_types, aux_types; + /*! \brief result holder for returning storage types */ + std::vector arg_storage_types, out_storage_types, aux_storage_types; /*! \brief result holder for returning shape dimensions */ std::vector arg_shape_ndim, out_shape_ndim, aux_shape_ndim; /*! 
\brief result holder for returning shape pointer */ diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index ca49402ecf7e..d9beb410e929 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -173,6 +173,9 @@ int MXExecutorBindEX(SymbolHandle symbol_handle, * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes * \param provided_arg_dtype_names argument name list of provided dtypes * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec * \param shared_arg_name_list parameter name list passed from _bind_ith_exec * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec @@ -205,6 +208,9 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, const mx_uint num_provided_arg_dtypes, const char** provided_arg_dtype_names, const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, const mx_uint num_shared_arg_names, const char** shared_arg_name_list, int* shared_buffer_len, @@ -229,7 +235,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, // attr_dict for setting up type_dict and arg/aux ctx std::unordered_map> attr_dict; - if (nullptr == provided_arg_dtypes || nullptr != g2c_keys) { + if (nullptr == provided_arg_dtypes || nullptr != g2c_keys || nullptr == provided_arg_stypes) { std::vector> attrs = sym->ListAttrsRecursive(); attr_dict.reserve(attrs.size()); @@ -255,6 +261,23 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, } } + // setup arg_stype_map + std::unordered_map arg_stype_map; + if (nullptr == provided_arg_stypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__storage_type__")) { + arg_stype_map[arg_name] = kDefaultStorage; + } + } + } else { // use user input type_dict + // create stype map for in_args and aux_states + arg_stype_map.reserve(num_provided_arg_stypes); + for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) { + arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i]; + } + } + // create default ctx Context ctx = Context::Create(static_cast(dev_type), dev_id); // create ctx map @@ -395,9 +418,10 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, std::vector aux_state_vec; *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, - aux_state_ctx_vec, arg_shape_map, arg_dtype_map, grad_req_type_vec, - shared_arg_name_set, &in_arg_vec, &arg_grad_vec, &aux_state_vec, - use_shared_buffer? &shared_buffer_map : nullptr, + aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_type_vec, shared_arg_name_set, &in_arg_vec, + &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, reinterpret_cast(shared_exec_handle)); // copy ndarray ptrs to ret->handles so that front end diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 0be1d3574dd9..8d190597ab0b 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -1,6 +1,6 @@ /*! 
* Copyright (c) 2016 by Contributors - * \file c_api_symbolic.cc + * \file c_api_ndarray.cc * \brief C API of mxnet */ @@ -16,6 +16,8 @@ #include "../common/utils.h" #include "../ndarray/autograd.h" +#define IMPERATIVE_EXEC_DEBUG 0 + using namespace mxnet; using mxnet::autograd::AutogradRuntime; @@ -122,16 +124,18 @@ void SetContext(Context* p_ctx, ctx = Context::CPU(); } } - +// Set the shape, dtype and storage type void SetShapeType(const nnvm::Op* op, const nnvm::NodeAttrs& attrs, const Context& ctx, const std::vector& ndinputs, const int& infered_num_outputs, - std::vector* p_ndoutputs) { + std::vector* p_ndoutputs, + int* dispatch_stype) { std::vector& ndoutputs = *p_ndoutputs; static auto& infershape = nnvm::Op::GetAttr("FInferShape"); static auto& infertype = nnvm::Op::GetAttr("FInferType"); + static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); // infer shape std::vector& in_shapes = ret->arg_shapes; @@ -167,9 +171,41 @@ void SetShapeType(const nnvm::Op* op, CHECK(infertype[op](attrs, &in_types, &out_types)); CHECK_EQ(out_types.size(), static_cast(infered_num_outputs)); + // infer storage type + auto& in_storage_types = ret->arg_storage_types; + auto& out_storage_types = ret->out_storage_types; + in_storage_types.clear(); + out_storage_types.clear(); + + for (auto& i : ndinputs) { + in_storage_types.push_back(i.storage_type()); + } + for (auto& i : ndoutputs) { + out_storage_types.push_back(i.storage_type()); + } + if (inferstorage.count(op)) { + CHECK(inferstorage[op](attrs, &in_storage_types, &out_storage_types)); + CHECK_EQ(out_storage_types.size(), static_cast(infered_num_outputs)); + } else { +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "FInferStorageType not present."; +#endif + } + + bool contains_non_default = common::ContainsNonDefaultStorage(in_storage_types); + contains_non_default |= common::ContainsNonDefaultStorage(out_storage_types); + int kNonDefaultStorage = -2; + *dispatch_stype = contains_non_default ? kNonDefaultStorage : kDefaultStorage; + for (int i = 0; i < infered_num_outputs; ++i) { + NDArrayStorageType storage_type = static_cast(out_storage_types[i]); if (ndoutputs[i].is_none()) { - ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); + // If failed to infer the storage type, assume the output storage is dense + if (storage_type == kDefaultStorage || out_storage_types[i] == kUndefinedStorage) { + ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); + } else { + ndoutputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]); + } } else { CHECK_EQ(ndoutputs[i].shape(), out_shapes[i]) << i << "th output has invalid shape. 
" @@ -216,23 +252,20 @@ void SetDependency(std::vector *p_read_vars, } CHECK_LE(ntmp, 1) << "Only support 1 temp space request"; } - - for (auto& i : ndinputs) { - read_vars.push_back(i.var()); - } - for (auto& i : ndoutputs) { - write_vars.push_back(i.var()); - } + for (auto& i : ndinputs) read_vars.emplace_back(i.var()); + for (auto& i : ndoutputs) write_vars.emplace_back(i.var()); if (mutate.count(op)) { auxidx = mutate[op](attrs); std::sort(auxidx.begin(), auxidx.end()); - for (auto & i : auxidx) { - write_vars.push_back(ndinputs[i].var()); + for (auto& i : auxidx) { + auto var = ndinputs[i].var(); + write_vars.push_back(var); } } Engine::Get()->DeduplicateVarHandle(&read_vars, &write_vars); } + void PushFCompute(const FCompute& fn, const nnvm::Op* op, const nnvm::NodeAttrs& attrs, @@ -242,23 +275,61 @@ void PushFCompute(const FCompute& fn, const std::vector& requested, const std::vector& ndinputs, const std::vector& ndoutputs) { + using namespace common; bool is_train = AutogradRuntime::Get()->IsTraining(); Engine::Get()->PushAsync( [ctx, attrs, fn, ndinputs, ndoutputs, requested, is_train]( RunContext rctx, engine::CallbackOnComplete on_complete) { std::vector input_blobs, output_blobs; - for (auto& i : ndinputs) { - input_blobs.push_back(i.data()); - } - for (auto& i : ndoutputs) { - output_blobs.push_back(i.data()); - } + std::vector temp_in; + std::vector temp_out; OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; - std::vector req(output_blobs.size(), kWriteTo); - fn(attrs, opctx, input_blobs, req, output_blobs); + if (ctx.dev_mask() == gpu::kDevMask) { +#if MXNET_USE_CUDA + GetDefaultBlobs(ndinputs, &input_blobs, &temp_in, opctx); + GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out, opctx); + std::vector req(output_blobs.size(), kWriteTo); + fn(attrs, opctx, input_blobs, req, output_blobs); + // cast to original storage type, if necessary + CastNonDefaultStorage(ndoutputs, temp_out, opctx); + rctx.get_stream()->Wait(); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(ndinputs, &input_blobs, &temp_in, opctx); + GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out, opctx); + std::vector req(output_blobs.size(), kWriteTo); + fn(attrs, opctx, input_blobs, req, output_blobs); + CastNonDefaultStorage(ndoutputs, temp_out, opctx); + } + on_complete(); + }, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, PROFILER_MESSAGE(op->name.c_str())); +} + +void PushFComputeEx(const FComputeEx& fn, + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& ndinputs, + const std::vector& ndoutputs) { + Engine::Get()->PushAsync( + [ctx, attrs, fn, ndinputs, ndoutputs, requested]( + RunContext rctx, + engine::CallbackOnComplete on_complete) { + std::vector input_blobs, output_blobs; + OpContext opctx{false, rctx, + engine::CallbackOnComplete(), + requested}; + std::vector req(ndoutputs.size(), kWriteTo); + fn(attrs, opctx, ndinputs, req, ndoutputs); if (ctx.dev_mask() == gpu::kDevMask) { rctx.get_stream()->Wait(); } @@ -327,8 +398,6 @@ void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs, NDArrayHandle *inputs, int *num_outputs, NDArrayHandle **outputs) { - static auto& fcpu = nnvm::Op::GetAttr("FCompute"); - static auto& fgpu = nnvm::Op::GetAttr("FCompute"); static auto& ndfunc = nnvm::Op::GetAttr("FNDArrayFunction"); static auto& createop = nnvm::Op::GetAttr("FCreateLayerOp"); 
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); @@ -337,20 +406,23 @@ void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs, int infered_num_outputs; int num_visible_outputs; - SetNumOutputs(op, attrs, num_inputs, - &infered_num_outputs, &num_visible_outputs); + SetNumOutputs(op, attrs, num_inputs, &infered_num_outputs, &num_visible_outputs); std::vector ndinputs, ndoutputs; SetNDInputsOutputs(op, &ndinputs, &ndoutputs, num_inputs, inputs, - num_outputs, infered_num_outputs, num_visible_outputs, outarray); + num_outputs, infered_num_outputs, num_visible_outputs, outarray); if (ndfunc.count(op)) { ndfunc[op](attrs, ndinputs, &ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "NDArray function executed."; +#endif } else { // TODO(piiswrong): infer ctx Context ctx; + int storage_type; SetContext(&ctx, attrs, num_inputs, ndinputs, infered_num_outputs, ndoutputs); - SetShapeType(op, attrs, ctx, ndinputs, infered_num_outputs, &ndoutputs); + SetShapeType(op, attrs, ctx, ndinputs, infered_num_outputs, &ndoutputs, &storage_type); std::vector read_vars, write_vars; std::vector requested; @@ -358,20 +430,24 @@ void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs, SetDependency(&read_vars, &write_vars, &requested, &auxidx, op, attrs, ctx, ndinputs, ndoutputs); - FCompute fn; - if (ctx.dev_mask() == cpu::kDevMask && fcpu.count(op)) { - fn = fcpu[op]; - } else if (ctx.dev_mask() == gpu::kDevMask && fgpu.count(op)) { - fn = fgpu[op]; - } - - if (fn) { + FCompute fn = common::GetFCompute(op, ctx); + FComputeEx fcomp_ex = common::GetFComputeEx(op, ctx, storage_type); + if (fcomp_ex) { + PushFComputeEx(fcomp_ex, op, attrs, ctx, read_vars, write_vars, requested, + ndinputs, ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "FComputeEx executed."; +#endif + } else if (fn) { if (AutogradRuntime::Get()->IsTraining()) { AutogradRuntime::Get()->RecordImperativeFCompute(op, attrs, &ndinputs, &ndoutputs); } PushFCompute(fn, op, attrs, ctx, read_vars, write_vars, requested, ndinputs, ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "FCompute executed."; +#endif } else if (createop.count(op)) { std::shared_ptr opr( createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types)); @@ -381,11 +457,14 @@ void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs, } PushOperator(opr, op, attrs, ctx, read_vars, write_vars, requested, auxidx, ndinputs, ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "CreateOp executed."; +#endif } else { LOG(FATAL) << "Operator " << op->name << " cannot be run; requires at least one of" - << " FCompute, NDArrayFunction, FCreateOperator be registered"; + << " FCompute, FComputeEx NDArrayFunction, FCreateOperator be registered"; } } diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index cad9e604df60..f4737fa8b3e2 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -512,6 +512,58 @@ int MXSymbolInferShapePartial(SymbolHandle sym, &succ); } +// TODO(haibin) refactor with infer_type +int MXSymbolInferStorageType(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int *arg_storage_type_data, + mx_uint *in_storage_type_size, + const int **in_storage_type_data, + mx_uint *out_storage_type_size, + const int **out_storage_type_data, + mx_uint *aux_storage_type_size, + const int **aux_storage_type_data, + int *complete) { + nnvm::Symbol *s = static_cast(sym); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + nnvm::Graph g = Symbol2Graph(*s); + nnvm::StorageTypeVector 
arg_storage_types(g.indexed_graph().input_nodes().size(), + kUndefinedStorage); + if (keys == nullptr && num_args != 0) { + std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); + CHECK_LE(num_args, read_only_args.size()); + for (mx_uint i = 0; i < num_args; ++i) { + arg_storage_types[read_only_args[i]] = arg_storage_type_data[i]; + } + } else { + std::unordered_map kwargs; + for (mx_uint i = 0; i < num_args; ++i) { + kwargs[keys[i]] = arg_storage_type_data[i]; + } + mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_storage_types, "InferStorageType"); + } + + g = nnvm::pass::InferStorageType(std::move(g), arg_storage_types, "__storage_type__"); + // copy back + CopyAttr(g.indexed_graph(), g.GetAttr("storage_type"), + &(ret->arg_storage_types), &(ret->out_storage_types), &(ret->aux_storage_types)); + + *in_storage_type_size = static_cast(ret->arg_storage_types.size()); + *in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types); + *out_storage_type_size = static_cast(ret->out_storage_types.size()); + *out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types); + *aux_storage_type_size = static_cast(ret->aux_storage_types.size()); + *aux_storage_type_data = dmlc::BeginPtr(ret->aux_storage_types); + *complete = (g.GetAttr("storage_type_num_unknown_nodes") == 0); + API_END(); +} + + int MXSymbolInferType(SymbolHandle sym, mx_uint num_args, const char** keys, diff --git a/src/common/utils.cc b/src/common/utils.cc new file mode 100644 index 000000000000..5bfb959fdf34 --- /dev/null +++ b/src/common/utils.cc @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file utils.cc + * \brief cpu implementation of util functions + */ + +#include "./utils.h" +#include "../operator/nn/cast_storage-inl.h" + +namespace mxnet { +namespace common { + + +template<> +void CastStorageDispatch(mshadow::Stream* s, + const NDArray& input, + const NDArray& output) { + mxnet::op::CastStorageComputeImpl(s, input, output); +} + + +} // namespace common +} // namespace mxnet diff --git a/src/common/utils.cu b/src/common/utils.cu new file mode 100644 index 000000000000..a249be5bb9f5 --- /dev/null +++ b/src/common/utils.cu @@ -0,0 +1,21 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file utils.cu + * \brief gpu implementation of util functions + */ + +#include "./utils.h" +#include "../operator/nn/cast_storage-inl.h" + +namespace mxnet { +namespace common { + +template<> +void CastStorageDispatch(mshadow::Stream* s, + const NDArray& input, + const NDArray& output) { + mxnet::op::CastStorageComputeImpl(s, input, output); +} + +} // namespace common +} // namespace mxnet diff --git a/src/common/utils.h b/src/common/utils.h index 789b4d14b9f2..19592affacac 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -6,7 +6,13 @@ #ifndef MXNET_COMMON_UTILS_H_ #define MXNET_COMMON_UTILS_H_ -#if DMLC_USE_CXX11 +#include +#include +#include +#include +#include +#include + #include #include #include @@ -14,15 +20,125 @@ #include #include #include -#endif // DMLC_USE_CXX11 - -#include -#include +#include namespace mxnet { + namespace common { -#if DMLC_USE_CXX11 +template +void CastStorageDispatch(mshadow::Stream* s, const NDArray& input, const NDArray& output); + +/* + * \brief Get the corresponding tensor blobs from default storage NDArrays. + * If any NDArray is of non-default storage, it is casted to default storage and + * the temporary NDArrays are stored in `temps`. When storage_fallback is false, + * and `MXNET_EXEC_STORAGE_FALLBACK` == 0, storage fallback is disallowed. + * \return true if any input is casted + */ +template +inline bool GetDefaultBlobs(const std::vector& nds, + std::vector *blobs, + std::vector *temps, + const OpContext& ctx, + bool storage_fallback = false) { + bool casted = false; + if (storage_fallback == false) { + storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true); + } + for (auto& nd : nds) { + if (nd.storage_type() != kDefaultStorage) { + if (storage_fallback == false) { + LOG(FATAL) << "Storage type conversion detected during execution. " + << "You are probably executing an operator which " + << "doesn't support NDArray inputs with non-default storage."; + } + NDArray temp(nd.shape(), nd.ctx(), false); + CastStorageDispatch(ctx.get_stream(), nd, temp); + temps->push_back(temp); + blobs->push_back(temp.data()); + casted = true; + } else { + blobs->push_back(nd.data()); + } + } + return casted; +} + +/* + * \brief Cast the NDArrays in `src` according to the storage types of the NDArrays + * in `dst`. The ones with default storage in `dst` are ignored. + * When storage_fallback is false, and `MXNET_EXEC_STORAGE_FALLBACK` == 0, + * storage fallback is disallowed. + */ +template +inline void CastNonDefaultStorage(const std::vector& dst, + const std::vector& src, + const OpContext& ctx, + bool storage_fallback = false) { + CHECK_GE(dst.size(), src.size()); + if (src.size() == 0) return; + if (storage_fallback == false) { + storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true); + } + size_t src_idx = 0; + for (size_t i = 0; i < dst.size(); i++) { + auto stype = dst[i].storage_type(); + if (stype != kDefaultStorage) { + if (storage_fallback == false) { + LOG(FATAL) << "Storage type conversion detected during execution. 
" + << "You are probably executing an operator which " + << "doesn't support NDArray inputs with non-default storage."; + } + CastStorageDispatch(ctx.get_stream(), src[src_idx++], dst[i]); + } + } + CHECK_EQ(src_idx, src.size()) << "Not all src NDArrays are casted"; +} + +// Check if any storage type is not default storage +inline bool ContainsNonDefaultStorage(const nnvm::StorageTypeVector& vstorage) { + for (auto& i : vstorage) { + if (i != kUndefinedStorage && i != kDefaultStorage) return true; + } + return false; +} + +inline bool ContainsDefaultStorage(const std::vector& ndarrays) { + for (auto &nd : ndarrays) { + if (nd.storage_type() == kDefaultStorage) { + return true; + } + } + return false; +} + +inline FCompute GetFCompute(const Op* op, Context ctx) { + static auto& fcompute_cpu = nnvm::Op::GetAttr("FCompute"); + static auto& fcompute_gpu = nnvm::Op::GetAttr("FCompute"); + if (ctx.dev_mask() == cpu::kDevMask) { + return fcompute_cpu.get(op, nullptr); + } else if (ctx.dev_mask() == gpu::kDevMask) { + return fcompute_gpu.get(op, nullptr); + } + LOG(FATAL) << "Unknown device mask"; + return nullptr; +} + +inline FComputeEx GetFComputeEx(const Op* op, Context ctx, int stype) { + static auto& fcpu = nnvm::Op::GetAttr("FComputeEx"); + static auto& fgpu = nnvm::Op::GetAttr("FComputeEx"); + if (stype == kDefaultStorage) return nullptr; + if (ctx.dev_mask() == cpu::kDevMask) { + return fcpu.get(op, nullptr); + } else if (ctx.dev_mask() == gpu::kDevMask) { + return fgpu.get(op, nullptr); + } + LOG(FATAL) << "Unknown device mask"; + return nullptr; +} + + // heuristic to dermine number of threads per GPU inline int GetNumThreadPerGPU() { // This is resource efficient option. @@ -37,6 +153,67 @@ inline int GetExecNumMatchColor() { return std::min(num_match_color, GetNumThreadPerGPU()); } +template +V ParallelAccumulate(const T* a, const int n, V start) { + V sum = start; +#pragma omp parallel for reduction(+:sum) + for (int i = 0; i < n; ++i) { + sum += a[i]; + } + return sum; +} + +/*! + * \brief + * Helper function for ParallelSort. + * DO NOT call this function directly. + * Use the interface ParallelSort instead. + * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h + */ +template +void ParallelSortHelper(RandomIt first, size_t len, + size_t grainsize, const Compare& comp) { + if (len < grainsize) { + std::sort(first, first+len, comp); + } else { + std::thread thr(ParallelSortHelper, first, len/2, grainsize, comp); + ParallelSortHelper(first+len/2, len - len/2, grainsize, comp); + thr.join(); + std::inplace_merge(first, first+len/2, first+len, comp); + } +} + +/*! + * \brief + * Sort the elements in the range [first, last) into the ascending order defined by + * the comparator comp. + * If the length of the range [first, last) is greater than a certain threshold, + * the range will be recursively divided into two and assign two threads + * to sort each half range. + * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h + */ +template +void ParallelSort(RandomIt first, RandomIt last, size_t num_threads, Compare comp) { + const auto num = std::distance(first, last); + size_t grainsize = std::max(num / num_threads + 5, static_cast(1024*16)); + ParallelSortHelper(first, num, grainsize, comp); +} + +/*! + * \brief + * Sort the elements in the range [first, last) into ascending order. + * The elements are compared using the default < operator. 
+ * If the length of the range [first, last) is greater than a certain threshold, + * the range will be recursively divided into two and assign two threads + * to sort each half range. + * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h + */ +template +void ParallelSort(RandomIt first, RandomIt last, size_t num_threads) { + ParallelSort(first, last, num_threads, + std::less::value_type>()); +} + /*! * \brief Random Engine */ @@ -124,8 +301,6 @@ typename helper::UniqueIf::UnknownBound MakeUnique(size_t n) { template typename helper::UniqueIf::KnownBound MakeUnique(Args&&... args) = delete; -#endif // DMLC_USE_CXX11 - } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 16b55adc15e8..0d718df41c9e 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -8,11 +8,15 @@ #include #include #include "./exec_pass.h" +#include "../common/utils.h" #if MXNET_USE_MKL2017 == 1 #include #include "../operator/mkl/mkl_memory-inl.h" #include "../operator/mkl/mkl_util-inl.h" #endif + +#define EXEC_ATTACH_OP_DEBUG 0 + namespace mxnet { namespace op { @@ -24,9 +28,33 @@ namespace exec { // forward executor class ForwardOpExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { + using namespace common; op_ctx.run_ctx = rctx; - op_->Forward(op_ctx, in_data_, req, out_data_, aux_data_); + + // If any input ndarray contains non-default storage, + // we need to cast it to default storage and setup the tblobs again. For example, + // if any of the input ndarray changes, the updated value won't be reflected in the temporary + // ndarray with default storage. 
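// Usage sketch for the ParallelSort and ParallelAccumulate helpers added to src/common/utils.h
// above. It assumes the example file is built inside the MXNet source tree with the project's
// include paths, -pthread, and (for the OpenMP reduction) -fopenmp; the include path below is an
// assumption about where such an example would live.
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>
#include "../common/utils.h"  // ParallelSort / ParallelAccumulate from this diff

int main() {
  std::vector<float> data(1 << 20);
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  for (auto& x : data) x = dist(rng);

  // Sort with up to 4 threads; ranges below the grain size fall back to std::sort.
  mxnet::common::ParallelSort(data.begin(), data.end(), 4);
  std::printf("sorted: %d\n", static_cast<int>(std::is_sorted(data.begin(), data.end())));

  // OpenMP reduction over the raw buffer, accumulating into a double.
  double sum = mxnet::common::ParallelAccumulate(data.data(), static_cast<int>(data.size()), 0.0);
  std::printf("sum = %f\n", sum);
  return 0;
}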
+ in_data_.clear(); out_data_.clear(); aux_data_.clear(); + temp_in_.clear(); temp_out_.clear(); temp_aux_.clear(); + if (is_gpu) { +#if MXNET_USE_CUDA + GetDefaultBlobs(in_array_, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(aux_array_, &aux_data_, &temp_aux_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + op_->Forward(op_ctx, in_data_, req, out_data_, aux_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(in_array_, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(aux_array_, &aux_data_, &temp_aux_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + op_->Forward(op_ctx, in_data_, req, out_data_, aux_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); + } #if MKL_EXPERIMENTAL == 1 mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); @@ -35,18 +63,14 @@ class ForwardOpExecutor : public OpExecutor { } void Setup() override { - in_data_.clear(); aux_data_.clear(); + // We need to tell whether in NDArray is input or aux for (size_t i = 0; i < in_array.size(); ++i) { if (!std::binary_search(aux_index_.begin(), aux_index_.end(), i)) { - in_data_.push_back(in_array[i].data()); + in_array_.emplace_back(in_array[i]); } else { - aux_data_.push_back(in_array[i].data()); + aux_array_.emplace_back(in_array[i]); } } - out_data_.resize(out_array.size()); - std::transform(out_array.begin(), out_array.end(), out_data_.begin(), [](const NDArray& nd) { - return nd.data(); - }); } Operator::ExecType exec_type() const override { return op_->exec_type(); @@ -62,12 +86,14 @@ class ForwardOpExecutor : public OpExecutor { std::shared_ptr op_; std::vector aux_index_; std::vector in_data_, out_data_, aux_data_; + std::vector in_array_, aux_array_, temp_in_, temp_aux_, temp_out_; }; // backward executor class BackwardOpExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { + // TODO(haibin) support storage fallback for BackwardOpExecutor op_ctx.run_ctx = rctx; op_->Backward(op_ctx, out_grad_, in_data_, out_data_, req, in_grad_, aux_data_); @@ -135,23 +161,36 @@ class BackwardOpExecutor : public OpExecutor { // fcompute executor executor class FComputeExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { + using namespace common; op_ctx.run_ctx = rctx; - fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + // setup blobs + // TODO(haibin) avoid repeating this if all inputs are already in default-storage. 
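// For reference, a stand-in sketch of the dispatch rule used both by ImperativeInvokeImpl above
// and by AttachOpExecs below: if any input or output entry carries a non-default storage type,
// the node is routed to the FComputeEx path (marked kNonDefaultStorage); otherwise it stays on
// the dense FCompute path. The enum values here are illustrative placeholders, not MXNet's
// actual definitions.
#include <vector>

namespace dispatch_sketch {

enum StorageType { kUndefinedStorage = -1, kDefaultStorage = 0, kRowSparseStorage = 1, kCSRStorage = 2 };
const int kNonDefaultStorage = -2;  // sentinel mirroring the one introduced in this diff

inline bool ContainsNonDefault(const std::vector<int>& stypes) {
  for (int s : stypes) {
    if (s != kUndefinedStorage && s != kDefaultStorage) return true;
  }
  return false;
}

// Decide which kernel path a node takes, given its input and output storage types.
inline int DispatchStorageType(const std::vector<int>& in_stypes,
                               const std::vector<int>& out_stypes) {
  const bool non_default = ContainsNonDefault(in_stypes) || ContainsNonDefault(out_stypes);
  return non_default ? kNonDefaultStorage : kDefaultStorage;
}

}  // namespace dispatch_sketch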
+ { + in_data_.clear(); out_data_.clear(); + temp_in_.clear(); temp_out_.clear(); + if (is_gpu) { +#if MXNET_USE_CUDA + GetDefaultBlobs(in_array, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(in_array, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); + } + } #if MKL_EXPERIMENTAL == 1 mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); #endif } - void Setup() override { - in_data_.resize(in_array.size()); - out_data_.resize(out_array.size()); - auto get_blob = [](const NDArray& nd) { - return nd.data(); - }; - std::transform(in_array.begin(), in_array.end(), in_data_.begin(), get_blob); - std::transform(out_array.begin(), out_array.end(), out_data_.begin(), get_blob); - } + void Setup() override {} Operator::ExecType exec_type() const override { return Operator::kSync; } @@ -159,28 +198,41 @@ class FComputeExecutor : public OpExecutor { : fcompute_(fcompute), attrs_(attrs) { } - static FCompute GetFCompute(const Op* op, Context ctx) { - static auto& fcompute_cpu = nnvm::Op::GetAttr("FCompute"); - static auto& fcompute_gpu = nnvm::Op::GetAttr("FCompute"); - if (ctx.dev_mask() == cpu::kDevMask) { - return fcompute_cpu.get(op, nullptr); - } else if (ctx.dev_mask() == gpu::kDevMask) { - return fcompute_gpu.get(op, nullptr); - } else { - LOG(FATAL) << "Unknown device mask"; - return nullptr; - } - } - private: FCompute fcompute_; NodeAttrs attrs_; std::vector in_data_, out_data_; + std::vector temp_in_, temp_out_; +}; + +// fcomputend executor +class FComputeExExecutor : public OpExecutor { + public: + void Run(RunContext rctx, bool is_gpu) override { + op_ctx.run_ctx = rctx; + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + } + void Setup() override { + in_data_ = in_array; + out_data_ = out_array; + } + Operator::ExecType exec_type() const override { + return Operator::kSync; + } + explicit FComputeExExecutor(FComputeEx fcompute, const NodeAttrs& attrs) + : fcompute_(fcompute), attrs_(attrs) { + } + + private: + FComputeEx fcompute_; + NodeAttrs attrs_; + std::vector in_data_, out_data_; }; // pass to attach operator executors Graph AttachOpExecs(Graph g) { using nnvm::DTypeVector; + using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::FMutateInputs; @@ -193,6 +245,7 @@ Graph AttachOpExecs(Graph g) { const auto& vctx = g.GetAttr("context"); const auto& saved_opr = g.GetAttr< std::unordered_map>>("saved_opr"); + const auto& dispatch_stypes = g.GetAttr("dispatch_stypes"); // get the graph const auto& idx = g.indexed_graph(); @@ -206,7 +259,12 @@ Graph AttachOpExecs(Graph g) { if (fmutate_inputs.count(inode.source->op())) { mutate_index = fmutate_inputs[inode.source->op()](inode.source->attrs); } - FCompute fcompute = FComputeExecutor::GetFCompute(inode.source->op(), vctx[i]); + FCompute fcompute = common::GetFCompute(inode.source->op(), vctx[i]); + FComputeEx fcompute_ex = + common::GetFComputeEx(inode.source->op(), vctx[i], dispatch_stypes[i]); +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "dispatch storage type = " << dispatch_stypes[i]; +#endif if (fcreate_layer_op.count(inode.source->op())) { std::vector ishape; std::vector itype; @@ -222,19 +280,33 @@ Graph 
AttachOpExecs(Graph g) { inode.source->attrs, vctx[i], ishape, itype)); } ret[i] = std::make_shared(opr, mutate_index); +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "ForwardOp for op " << inode.source->op()->name; +#endif } else if (is_layer_backward.get(inode.source->op(), false)) { CHECK_GE(inode.control_deps.size(), 1); uint32_t fwd_id = inode.control_deps[0]; CHECK(vctx[fwd_id] == vctx[i]); CHECK(ret[fwd_id] != nullptr); + CHECK_EQ(dispatch_stypes[i], kDefaultStorage) + << "BackwardOp doesn't handle non-default storage yet"; ret[i] = std::make_shared( dynamic_cast(ret[fwd_id].get())->op_, mxnet::op::OpPropGetOpProperty(inode.source->attrs), mutate_index); +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "BackwardOp for op " << inode.source->op()->name; +#endif + } else if (fcompute_ex != nullptr) { +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "FComputeEx for op " << inode.source->op()->name; +#endif + ret[i] = std::make_shared(fcompute_ex, inode.source->attrs); } else if (fcompute != nullptr) { +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "FCompute for op " << inode.source->op()->name; +#endif ret[i] = std::make_shared(fcompute, inode.source->attrs); - } else { - LOG(INFO) << "FCompute not registered " << inode.source->op()->name; } } g.attrs["op_execs"] = std::make_shared(ret); diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 8df6a3c5d3bb..20535be320d9 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -19,6 +19,12 @@ namespace exec { /*! \brief reuse graph definition */ using nnvm::Graph; +const int kBadStorageID = -1; +const int kExternalStorageID = -2; +const int kDynamicStorageID = -3; + +const int kNonDefaultStorage = -2; + /*! * \brief executor to execute an operator * This is a graph executor dependent interface @@ -26,7 +32,7 @@ using nnvm::Graph; */ class OpExecutor { public: - /*! \brief input arrays */ + /*! \brief input data arrays, which may be either input or aux */ std::vector in_array; /*! \brief output data arrays */ std::vector out_array; @@ -47,7 +53,7 @@ class OpExecutor { * This function call do not synchronize the stream. * \param rctx The runtime context passed in by environment. */ - virtual void Run(RunContext rctx) = 0; + virtual void Run(RunContext rctx, bool is_gpu) = 0; /*! \return the execution type */ virtual Operator::ExecType exec_type() const = 0; }; diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index d60c5e46e52c..de8411a7be95 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -12,6 +12,7 @@ #include "./exec_pass.h" #include "./graph_executor.h" #include "../engine/profiler.h" +#include "../common/utils.h" namespace mxnet { namespace exec { @@ -29,6 +30,30 @@ GraphExecutor::~GraphExecutor() { } } +inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype) { + // NDArray with default storage + if (stype == kDefaultStorage) { + NDArray ret(shape, ctx, false, dtype); + ret = 0; + return ret; + } + // NDArray with non-default storage. Storage allocation is always delayed. + return NDArray(stype, shape, ctx, true, dtype); +} + +inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype, + std::vector *vec) { + // NDArray with default storage + if (stype == kDefaultStorage) { + vec->emplace_back(shape, ctx, false, dtype); + vec->back() = 0; + } else { + // NDArray with non-default storage. Storage allocation is always delayed. 
+ vec->emplace_back(stype, shape, ctx, true, dtype); + } +} void GraphExecutor::Forward(bool is_train) { RunOps(is_train, 0, num_forward_nodes_); } @@ -442,21 +467,25 @@ void GraphExecutor::Init(nnvm::Symbol symbol, data_entry_.resize(idx.num_node_entries()); nnvm::ShapeVector arg_shapes; nnvm::DTypeVector arg_dtypes; + nnvm::StorageTypeVector arg_stypes; for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& arg_name = idx[nid].source->attrs.name; + size_t eid = idx.entry_id(nid, 0); if (mutable_nodes.count(nid)) { CHECK_LT(aux_top, aux_states.size()); - data_entry_[idx.entry_id(nid, 0)] = aux_states[aux_top]; + data_entry_[eid] = aux_states[aux_top]; arg_shapes.push_back(aux_states[aux_top].shape()); arg_dtypes.push_back(aux_states[aux_top].dtype()); + arg_stypes.push_back(aux_states[aux_top].storage_type()); aux_state_map_.emplace(arg_name, aux_states[aux_top]); ++aux_top; } else { CHECK_LT(arg_top, in_args.size()); - data_entry_[idx.entry_id(nid, 0)] = in_args[arg_top]; + data_entry_[eid] = in_args[arg_top]; arg_shapes.push_back(in_args[arg_top].shape()); arg_dtypes.push_back(in_args[arg_top].dtype()); + arg_stypes.push_back(in_args[arg_top].storage_type()); in_arg_map_.emplace(arg_name, in_args[arg_top]); if (kNullOp != grad_req_types[arg_top]) { grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_store[arg_top]); @@ -464,6 +493,10 @@ void GraphExecutor::Init(nnvm::Symbol symbol, } ++arg_top; } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign data entry\t" << eid << " as stype " + << data_entry_[eid].storage_type() << " (input)"; +#endif } // expand arg_shapes and arg_dtypes to contain backward inputs @@ -480,6 +513,8 @@ void GraphExecutor::Init(nnvm::Symbol symbol, HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } + // TODO(haibin) better error message for infer_storage + g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); // Initialize the rest attributes of the graph. 
// This function can be called by regular bind @@ -496,6 +531,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -513,22 +549,37 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const uint32_t eid = idx.entry_id(nid, 0); const TShape& inferred_shape = inferred_shapes[eid]; const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; const std::string& arg_name = idx[nid].source->attrs.name; if (mutable_nodes.count(nid)) { // aux_states - aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top], false, inferred_dtype); - aux_state_vec->back() = 0; + EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], + inferred_dtype, aux_state_vec); data_entry_[eid] = aux_state_vec->back(); aux_state_map_.emplace(arg_name, aux_state_vec->back()); ++aux_top; +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign aux entry\t" << eid << "\t as stype " << inferred_stype; +#endif } else { // in_args - in_arg_vec->emplace_back(inferred_shape, in_arg_ctxes[arg_top], false, inferred_dtype); - in_arg_vec->back() = 0; + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); data_entry_[eid] = in_arg_vec->back(); +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign data entry\t" << eid << "\tas stype " << inferred_stype; +#endif + // Get the storage type for grad if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); } else { - arg_grad_vec->emplace_back(inferred_shape, arg_grad_ctxes[arg_top], false, inferred_dtype); - arg_grad_vec->back() = 0; + // Init based on storage type + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign grad entry\t" << grad_eid << "\tas stype " << grad_stype; +#endif grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); } @@ -540,33 +591,40 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, /*! * \brief If the requested ndarray's shape size is less than - * the corresponding shared_data_array's shape size, reuse - * the memory allocation; otherwise, create a zero ndarray. + * the corresponding shared_data_array's shape size and the + * storage type is default storage, reuse the memory allocation + * in shared_buffer; otherwise, create a zero ndarray. 
*/ NDArray ReshapeOrCreate(const std::string& name, const TShape& dest_arg_shape, const int dest_arg_dtype, + const NDArrayStorageType dest_arg_stype, const Context& ctx, std::unordered_map* shared_buffer) { + if (dest_arg_stype != kDefaultStorage) { + return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + } auto it = shared_buffer->find(name); if (it != shared_buffer->end()) { if (it->second.shape().Size() >= dest_arg_shape.Size()) { // memory can be reused CHECK_EQ(it->second.dtype(), dest_arg_dtype) << "Requested arg array's dtype does not match the reusable ndarray"; + CHECK_EQ(it->second.storage_type(), kDefaultStorage) + << "shared_buffer should only contain NDArrays with default storage type."; return it->second.Reshape(dest_arg_shape); } else { LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape << ", which is larger than already allocated shape " << it->second.shape() << ". Need to re-allocate. Consider putting default bucket key to be " << "the bucket taking the largest input for better memory sharing."; - it->second = NDArray(dest_arg_shape, ctx, false, dest_arg_dtype); - it->second = 0; + // the NDArrays in shared_buffer are guaranteed to be of default storage + it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); return it->second; } // arg_array.shape().Size() >= arg_shape.Size() } else { - auto p = shared_buffer->emplace(name, NDArray(dest_arg_shape, ctx, false, dest_arg_dtype)); - p.first->second = 0; - return p.first->second; + auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + shared_buffer->emplace(name, ret); + return ret; } // if (it != shared_buffer->end()) } @@ -579,6 +637,7 @@ NDArray ReshapeOrCreate(const std::string& name, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -598,9 +657,12 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const uint32_t eid = idx.entry_id(nid, 0); const TShape& inferred_shape = inferred_shapes[eid]; const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; const std::string& arg_name = idx[nid].source->attrs.name; - if (mutable_nodes.count(nid)) { // aux_states - if (nullptr != shared_exec) { + // aux_states + if (mutable_nodes.count(nid)) { + if (nullptr != shared_exec && inferred_stype == kDefaultStorage && + shared_exec->aux_state_map().at(arg_name).storage_type() == kDefaultStorage) { const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name); CHECK_EQ(inferred_shape, aux_nd.shape()) << "Inferred shape does not match shared_exec.aux_array's shape." 
@@ -614,16 +676,18 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, << arg_name << " for the current executor"; aux_state_vec->emplace_back(aux_nd); } else { - aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top], - false, inferred_dtype); - aux_state_vec->back() = 0; + EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], + inferred_dtype, aux_state_vec); } // if (has_shared_exec) data_entry_[eid] = aux_state_vec->back(); aux_state_map_.emplace(arg_name, aux_state_vec->back()); ++aux_top; - } else { // in_args + } else { // in_args and grad for in_args if (shared_arg_names.count(arg_name)) { // model parameter - if (nullptr != shared_exec) { + // model parameter + if (nullptr != shared_exec && inferred_stype == kDefaultStorage && + shared_exec->in_arg_map().at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name); CHECK_EQ(inferred_shape, in_arg_nd.shape()) << "Inferred shape does not match shared_exec.arg_array's shape" @@ -636,33 +700,43 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, " be resued for creating NDArray of the argument" << arg_name << " for the current executor"; in_arg_vec->emplace_back(in_arg_nd); - if (kNullOp == grad_req_types[arg_top]) { - arg_grad_vec->emplace_back(); - } else { + } else { + // doesn't have shared_exec, or non-default storage + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); + } + // gradient for model parameter + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + if (nullptr != shared_exec && grad_stype == kDefaultStorage && + shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name)); - grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); - } // if (kNullOp == grad_req_types[arg_top]) - } else { // !has shared_exec - in_arg_vec->emplace_back(inferred_shape, in_arg_ctxes[arg_top], false, inferred_dtype); - in_arg_vec->back() = 0; - if (kNullOp == grad_req_types[arg_top]) { - arg_grad_vec->emplace_back(); } else { - arg_grad_vec->emplace_back(inferred_shape, arg_grad_ctxes[arg_top], - false, inferred_dtype); - arg_grad_vec->back() = 0; - grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); - } // if (kNullOp == grad_req_types[arg_top]) - } // if (has_shared_exec) + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); + } + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } } else { // !shared_arg_names.count(arg_name) + // model parameter in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype, - in_arg_ctxes[arg_top], shared_buffer)); + inferred_stype, in_arg_ctxes[arg_top], + shared_buffer)); + // gradient for model parameter if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; arg_grad_vec->emplace_back(ReshapeOrCreate("grad 
of " + arg_name, inferred_shape, - inferred_dtype, arg_grad_ctxes[arg_top], - shared_buffer)); + inferred_dtype, grad_stype, + arg_grad_ctxes[arg_top], shared_buffer)); grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); } // if (kNullOp == grad_req_types[arg_top]) } // if (shared_arg_names.count(arg_name)) @@ -685,14 +759,35 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, Executor* shared_exec, const nnvm::NodeEntryMap& feed_dict) { const auto& idx = g.indexed_graph(); + // dispatch based on stype per operator + const auto& vstorage_type = g.GetAttr("storage_type"); + nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage); + for (size_t nid = 0; nid < idx.num_nodes(); nid++) { + const auto& inode = idx[nid]; + auto num_outputs = inode.source->num_outputs(); + auto num_inputs = inode.inputs.size(); + nnvm::StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage); + for (size_t i = 0; i < num_inputs; i++) { + auto e = inode.inputs[i]; + vs[i] = vstorage_type[idx.entry_id(e)]; + CHECK_NE(vs[i], kUndefinedStorage); + } + for (uint32_t i = 0; i < num_outputs; ++i) { + uint32_t eid = idx.entry_id(nid, i); + vs[i + num_inputs] = vstorage_type[eid]; + } + bool contains_non_default = common::ContainsNonDefaultStorage(vs); + dispatch_stypes[nid] = contains_non_default ? kNonDefaultStorage : kDefaultStorage; + } + g.attrs["dispatch_stypes"] = std::make_shared(std::move(dispatch_stypes)); + + // data entries for output gradients for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second; } { // memory allocator - const int kBadStorageID = -1; - const int kExternalStorageID = -2; nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID); for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID; @@ -702,6 +797,9 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, data_entry_[eid] = kv.second; arg_storage_id[eid] = kExternalStorageID; } + for (size_t i = 0; i < idx.num_node_entries(); i++) { + if (vstorage_type[i] != kDefaultStorage) arg_storage_id[i] = kDynamicStorageID; + } g.attrs["storage"] = std::make_shared(std::move(arg_storage_id)); g = nnvm::ApplyPass(g, "PlanMemory"); } @@ -759,6 +857,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_arg_vec, @@ -778,6 +877,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const nnvm::IndexedGraph& idx = g.indexed_graph(); nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + nnvm::DTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& name = idx[nid].source->attrs.name; @@ -789,6 +889,10 @@ void GraphExecutor::Init(nnvm::Symbol symbol, if (arg_dtype_map.end() != it2) { arg_dtypes[i] = it2->second; } + auto it3 = arg_stype_map.find(name); + if (arg_stype_map.end() != it3) { + arg_stypes[i] = it3->second; + } } g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); if (g.GetAttr("shape_num_unknown_nodes") != 0U) { @@ -801,17 +905,21 @@ void 
GraphExecutor::Init(nnvm::Symbol symbol, HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } + // TODO(jun/haibin) check if InferShape is successful, and give warnings instead of segfault later + g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes. if (nullptr == shared_buffer) { // regular simple bind InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); } else { // simple bind using shared data arrays and shared_exec InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, shared_arg_names, shared_exec, shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); @@ -864,6 +972,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, // initialize the memory of each entries void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { using nnvm::DTypeVector; + using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::StorageVector; // get the graph @@ -872,20 +981,29 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { const auto& vdtype = graph_.GetAttr("dtype"); const auto& vshape = graph_.GetAttr("shape"); const auto& vstorage = graph_.GetAttr("storage_id"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); const auto& vctx = graph_.GetAttr("context"); CHECK_EQ(idx.num_node_entries(), vshape.size()); CHECK_EQ(idx.num_node_entries(), vdtype.size()); CHECK_EQ(idx.num_node_entries(), vstorage.size()); CHECK_EQ(data_entry_.size(), vshape.size()); std::vector data_context(idx.num_node_entries()); + std::vector data_storage_type(idx.num_node_entries(), kUndefinedStorage); for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t i = 0; i < idx[nid].source->num_outputs(); ++i) { - data_context[idx.entry_id(nid, i)] = vctx[nid]; + auto eid = idx.entry_id(nid, i); + data_context[eid] = vctx[nid]; + CHECK_NE(vstorage_type[nid], kUndefinedStorage); + data_storage_type[eid] = (NDArrayStorageType) vstorage_type[nid]; } } // information about the pool - using PoolEntry = std::pair; + struct PoolEntry { + Context ctx; + size_t bytes; + NDArrayStorageType stype; + }; std::vector pool_info; // assign array to head gradient @@ -893,26 +1011,36 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { uint32_t nid = idx.input_nodes().at(i); uint32_t oid = head_grad_map_.at(idx[nid].source); uint32_t eid = idx.entry_id(idx.outputs()[oid]); + NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid]; CHECK_NE(vshape[eid].ndim(), 0U); CHECK_NE(vdtype[eid], -1); - data_entry_[idx.entry_id(nid, 0)] = - NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); + auto data_eid = idx.entry_id(nid, 0); + // initialize based on storage_type + if (stype != kDefaultStorage) { + data_entry_[data_eid] = NDArray(stype, vshape[eid], data_context[eid], true, vdtype[eid]); + } else { + data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); + } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tinit head_g entry\t" << data_eid << "\tas stype " << stype; +#endif } // get maximum bytes in each pool for (size_t i = 0; i < vshape.size(); ++i) { if (!data_entry_[i].is_none()) continue; size_t bytes = vshape[i].Size() * mshadow::mshadow_sizeof(vdtype[i]); int storage_id = 
vstorage[i]; + // skip pool allocation for kBadStorageID, kExternalStorageID and kDynamicStorageID if (storage_id < 0) continue; size_t sid = static_cast(storage_id); if (sid >= pool_info.size()) { - pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0)}); + pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0), kUndefinedStorage}); } PoolEntry& info = pool_info[sid]; - if (info.second == 0) { - info = PoolEntry{data_context[i], bytes}; + if (info.bytes == 0) { + info = PoolEntry{data_context[i], bytes, data_storage_type[i]}; } else { - info.second = std::max(info.second, bytes); + info.bytes = std::max(info.bytes, bytes); } } // construct the re-use pool, if needed @@ -933,13 +1061,14 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { sorted_pool_index.push_back(i); } auto pool_comparator = [&pool_info](int lhs, int rhs){ - return pool_info[lhs].second > pool_info[rhs].second; + return pool_info[lhs].bytes > pool_info[rhs].bytes; }; std::sort(sorted_pool_index.begin(), sorted_pool_index.end(), pool_comparator); for (size_t i : sorted_pool_index) { - const Context& ctx = pool_info[i].first; - size_t bytes = pool_info[i].second; + const Context& ctx = pool_info[i].ctx; + size_t bytes = pool_info[i].bytes; + NDArrayStorageType storage_type = pool_info[i].stype; bool allocated = false; for (auto it = free_pool.lower_bound(bytes); it != free_pool.end(); ++it) { if (it->second.ctx() == ctx && it->first >= bytes) { @@ -964,15 +1093,22 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { } CHECK_EQ(data_pool_.size(), pool_info.size()); // assign the data entries - for (size_t i = 0; i < data_entry_.size(); ++i) { // avoid pre-allocated arrays if (!data_entry_[i].is_none()) continue; // assign allocated array by storage id int storage_id = vstorage[i]; - CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; - const NDArray& src = data_pool_.at(storage_id); - data_entry_[i] = src.AsArray(vshape[i], vdtype[i]); + auto storage_type = (NDArrayStorageType) vstorage_type[i]; + if (storage_type == kDefaultStorage) { + CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; + const NDArray& src = data_pool_.at(storage_id); + data_entry_[i] = src.AsArray(vshape[i], vdtype[i]); + } else { + data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i]); + } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tinit data entry\t" << i << "\tas stype " << storage_type; +#endif } } @@ -987,11 +1123,28 @@ void GraphExecutor::InitCachedOps() { const auto& vctx = graph_.GetAttr("context"); const auto& addto_entry = graph_.GetAttr >("addto_entry"); const auto& skip_plus_node = graph_.GetAttr >("skip_plus_node"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); op_nodes_.resize(idx.num_nodes()); // setup the array and requirements. 
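// A stand-in sketch of the pool sizing done in InitDataEntryMemory above: dense entries sharing
// a storage id reuse one buffer sized to the largest request, while entries whose storage id is
// negative (kBadStorageID, kExternalStorageID, or the kDynamicStorageID used for sparse outputs)
// are excluded from the shared pool. Plain ids and byte counts stand in for NDArray, Context,
// and PoolEntry; SizeSharedPool is an illustrative name.
#include <cstddef>
#include <unordered_map>
#include <vector>

namespace pool_sketch {

const int kDynamicStorageID = -3;  // sentinel mirroring exec_pass.h in this diff

// Returns, for each non-negative storage id, the largest byte request seen.
inline std::unordered_map<int, std::size_t>
SizeSharedPool(const std::vector<int>& storage_ids, const std::vector<std::size_t>& bytes) {
  std::unordered_map<int, std::size_t> pool;
  for (std::size_t i = 0; i < storage_ids.size(); ++i) {
    const int sid = storage_ids[i];
    if (sid < 0) continue;  // dynamic/external/bad ids allocate outside the shared pool
    std::size_t& slot = pool[sid];
    if (bytes[i] > slot) slot = bytes[i];
  }
  return pool;
}

}  // namespace pool_sketch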
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; +#if EXECUTOR_DEBUG + if (inode.source->is_variable()) { + LOG(INFO) << "node " << nid << " var"; + } else { + LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name; + auto exec = op_execs[nid]; + for (const auto& e : inode.inputs) { + auto eid = idx.entry_id(e); + LOG(INFO) << "\t\tinput " << eid << " stype: " << vstorage_type[eid]; + } + for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { + uint32_t eid = idx.entry_id(nid, index); + LOG(INFO) << "\t\toutput " << eid << " stype: " << vstorage_type[eid]; + } + } +#endif if (inode.source->is_variable()) continue; #if MXNET_USE_PROFILER op_nodes_[nid].opr_name = inode.source->op()->name.c_str(); @@ -1068,7 +1221,7 @@ void GraphExecutor::InitCachedOps() { if (is_async) { exec->op_ctx.async_on_complete = on_complete; } - exec->Run(ctx); + exec->Run(ctx, is_gpu); // call on complete only if it is async op if (!is_async) { if (is_gpu) { @@ -1213,6 +1366,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; +#endif +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid << " - " << seg_op.topo_end - 1; #endif Engine::Get()->Push(seg_op.opr, seg_op.ctx, 0, profiling); nid = seg_op.topo_end - 1; @@ -1225,6 +1381,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { if (op_nodes_[nid].skip_exec_node) continue; opnode.exec->op_ctx.is_train = is_train; if (opnode.exec->exec_type() == Operator::kCrossDeviceCopy) { +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid << " for CrossDeviceCopy"; +#endif CHECK_EQ(inode.inputs.size(), 1U); CHECK_EQ(opnode.exec->in_array.size(), 1U); CHECK_EQ(opnode.exec->out_array.size(), 1U); @@ -1234,6 +1393,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; +#endif +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid; #endif Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling); } else { @@ -1298,7 +1460,7 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, RunContext ctx, Engine::CallbackOnComplete on_complete) { // Run all opr in the sub-graph for (auto &exec : exec_list) { - exec->Run(ctx); + exec->Run(ctx, is_gpu); } if (is_gpu) { #if MXNET_USE_CUDA @@ -1333,6 +1495,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_args, @@ -1343,7 +1506,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, auto exec = new exec::GraphExecutor(); exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, - arg_shape_map, arg_dtype_map, + arg_shape_map, arg_dtype_map, arg_stype_map, grad_req_types, shared_arg_names, in_args, arg_grads, aux_states, shared_buffer, shared_exec); diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index d5a4e8c3aa6c..308eddba8b80 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -19,6 +19,8 @@ #include #include "./exec_pass.h" +#define EXECUTOR_DEBUG 0 + namespace mxnet { using 
NodeOperatorMap = std::unordered_map& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_arg_vec, @@ -126,6 +129,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -138,6 +142,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -186,7 +191,8 @@ class GraphExecutor : public Executor { std::vector op_nodes_; // internal data entry of each node std::vector data_entry_; - // internal data pool of allocated entries + // internal data pool of allocated entries. + // these allocated entries can be used for static memory sharing between executors. std::vector data_pool_; // output arrays std::vector output_arrays_; diff --git a/src/executor/inplace_addto_detect_pass.cc b/src/executor/inplace_addto_detect_pass.cc index 75a2608313aa..1a0bc9cb40a6 100644 --- a/src/executor/inplace_addto_detect_pass.cc +++ b/src/executor/inplace_addto_detect_pass.cc @@ -44,6 +44,8 @@ Graph DetectInplaceAddTo(Graph g) { uint32_t eid_rhs = idx.entry_id(inode.inputs[1]); if (ref_count[eid_rhs] != 1) continue; if (inode.inputs[0].node_id >= inode.inputs[1].node_id) continue; + // TODO(haibin) support inplace addto for Dynamic Storage + if (storage_id[eid_rhs] == kDynamicStorageID) continue; CHECK_NE(storage_id[eid_rhs], sid); storage_id[eid_rhs] = sid; addto_entry[eid_rhs] = 1; diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index a51e24503785..91488c065033 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -23,7 +23,7 @@ namespace io { class BatchLoader : public IIterator { public: explicit BatchLoader(IIterator *base): - base_(base), head_(1), num_overflow_(0) { + head_(1), num_overflow_(0), base_(base) { } virtual ~BatchLoader(void) { @@ -34,7 +34,7 @@ class BatchLoader : public IIterator { std::vector > kwargs_left; // init batch param, it could have similar param with kwargs_left = param_.InitAllowUnknown(kwargs); - // Init space for out_ + // Init space for out out_.inst_index = new unsigned[param_.batch_size]; out_.batch_size = param_.batch_size; out_.data.clear(); @@ -51,6 +51,7 @@ class BatchLoader : public IIterator { } head_ = 1; } + virtual bool Next(void) { out_.num_batch_padd = 0; out_.batch_size = param_.batch_size; @@ -110,23 +111,25 @@ class BatchLoader : public IIterator { return out_; } - private: + protected: /*! \brief batch parameters */ BatchParam param_; /*! \brief output data */ TBlobBatch out_; - /*! \brief base iterator */ - IIterator *base_; /*! \brief on first */ int head_; /*! \brief number of overflow instances that readed in round_batch mode */ int num_overflow_; + /*! \brief tensor to hold data */ + std::vector data_; + + private: + /*! \brief base iterator */ + IIterator *base_; /*! \brief data shape */ std::vector shape_; /*! \brief unit size */ std::vector unit_size_; - /*! 
\brief tensor to hold data */ - std::vector data_; // initialize the data holder by using from the first batch. inline void InitData(const DataInst& first_batch) { shape_.resize(first_batch.data.size()); diff --git a/src/io/iter_libsvm.cc b/src/io/iter_libsvm.cc new file mode 100644 index 000000000000..04dcf289a020 --- /dev/null +++ b/src/io/iter_libsvm.cc @@ -0,0 +1,258 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file iter_libsvm.cc + * \brief define a LibSVM Reader to read in arrays + */ +#include +#include +#include +#include +#include +#include "./iter_sparse_prefetcher.h" +#include "./iter_sparse_batchloader.h" + +namespace mxnet { +namespace io { +// LibSVM parameters +struct LibSVMIterParam : public dmlc::Parameter { + /*! \brief path to data libsvm file */ + std::string data_libsvm; + /*! \brief data shape */ + TShape data_shape; + /*! \brief path to label libsvm file */ + std::string label_libsvm; + /*! \brief label shape */ + TShape label_shape; + // declare parameters + DMLC_DECLARE_PARAMETER(LibSVMIterParam) { + DMLC_DECLARE_FIELD(data_libsvm) + .describe("The input LibSVM file or a directory path."); + DMLC_DECLARE_FIELD(data_shape) + .describe("The shape of one example."); + DMLC_DECLARE_FIELD(label_libsvm).set_default("NULL") + .describe("The input LibSVM file or a directory path. " + "If NULL, all labels will be read from ``data_libsvm``."); + index_t shape1[] = {1}; + DMLC_DECLARE_FIELD(label_shape).set_default(TShape(shape1, shape1 + 1)) + .describe("The shape of one label."); + } +}; + +class LibSVMIter: public SparseIIterator { + public: + LibSVMIter() {} + virtual ~LibSVMIter() {} + + // intialize iterator loads data in + virtual void Init(const std::vector >& kwargs) { + param_.InitAllowUnknown(kwargs); + CHECK_EQ(param_.data_shape.ndim(), 1) << "dimension of data_shape is expected to be 1"; + data_parser_.reset(dmlc::Parser::Create(param_.data_libsvm.c_str(), + 0, 1, "libsvm")); + if (param_.label_libsvm != "NULL") { + label_parser_.reset(dmlc::Parser::Create(param_.label_libsvm.c_str(), + 0, 1, "libsvm")); + CHECK_GT(param_.label_shape.Size(), 1) + << "label_shape is not expected to be (1,) when param_.label_libsvm is set."; + } else { + CHECK_EQ(param_.label_shape.Size(), 1) + << "label_shape is expected to be (1,) when param_.label_libsvm is NULL"; + } + // both data and label are of CSRStorage in libsvm format + if (param_.label_shape.Size() > 1) { + out_.data.resize(6); + } else { + // only data is of CSRStorage in libsvm format. 
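// A self-contained sketch of the per-row format this iterator consumes: one label followed by
// zero-based index:value pairs, which correspond to the value and index blobs produced by
// AsDataBlob and AsIdxBlob below. It uses plain std::istringstream parsing rather than the
// dmlc::Parser the iterator relies on; LibSVMRow and ParseLibSVMLine are illustrative names.
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

namespace libsvm_sketch {

struct LibSVMRow {
  float label;
  std::vector<int64_t> indices;  // zero-based column indices
  std::vector<float> values;     // corresponding non-zero values
};

inline LibSVMRow ParseLibSVMLine(const std::string& line) {
  LibSVMRow row;
  std::istringstream is(line);
  is >> row.label;
  std::string pair;
  while (is >> pair) {
    const std::size_t colon = pair.find(':');
    row.indices.push_back(std::stoll(pair.substr(0, colon)));
    row.values.push_back(std::stof(pair.substr(colon + 1)));
  }
  return row;
}

// Example: ParseLibSVMLine("1.0 0:0.5 2:1.2") yields label 1.0, indices {0, 2}, values {0.5, 1.2}.

}  // namespace libsvm_sketch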
+ out_.data.resize(4); + } + } + + virtual void BeforeFirst() { + data_parser_->BeforeFirst(); + if (label_parser_.get() != nullptr) { + label_parser_->BeforeFirst(); + } + data_ptr_ = label_ptr_ = 0; + data_size_ = label_size_ = 0; + inst_counter_ = 0; + end_ = false; + } + + virtual bool Next() { + if (end_) return false; + while (data_ptr_ >= data_size_) { + if (!data_parser_->Next()) { + end_ = true; return false; + } + data_ptr_ = 0; + data_size_ = data_parser_->Value().size; + } + out_.index = inst_counter_++; + CHECK_LT(data_ptr_, data_size_); + const auto data_row = data_parser_->Value()[data_ptr_++]; + // data, indices and indptr + out_.data[0] = AsDataBlob(data_row); + out_.data[1] = AsIdxBlob(data_row); + out_.data[2] = AsIndPtrPlaceholder(data_row); + + if (label_parser_.get() != nullptr) { + while (label_ptr_ >= label_size_) { + CHECK(label_parser_->Next()) + << "Data LibSVM's row is smaller than the number of rows in label_libsvm"; + label_ptr_ = 0; + label_size_ = label_parser_->Value().size; + } + CHECK_LT(label_ptr_, label_size_); + const auto label_row = label_parser_->Value()[label_ptr_++]; + // data, indices and indptr + out_.data[3] = AsDataBlob(label_row); + out_.data[4] = AsIdxBlob(label_row); + out_.data[5] = AsIndPtrPlaceholder(label_row); + } else { + out_.data[3] = AsScalarLabelBlob(data_row); + } + return true; + } + + virtual const DataInst &Value(void) const { + return out_; + } + + virtual const NDArrayStorageType GetStorageType(bool is_data) const { + if (is_data) return kCSRStorage; + return param_.label_shape.Size() > 1 ? kCSRStorage : kDefaultStorage; + } + + virtual const TShape GetShape(bool is_data) const { + if (is_data) return param_.data_shape; + return param_.label_shape; + } + + private: + inline TBlob AsDataBlob(const dmlc::Row& row) { + const real_t* ptr = row.value; + TShape shape(mshadow::Shape1(row.length)); + return TBlob((real_t*) ptr, shape, cpu::kDevMask); // NOLINT(*) + } + + inline TBlob AsIdxBlob(const dmlc::Row& row) { + const uint64_t* ptr = row.index; + TShape shape(mshadow::Shape1(row.length)); + return TBlob((int64_t*) ptr, shape, cpu::kDevMask, mshadow::kInt64); // NOLINT(*) + } + + inline TBlob AsIndPtrPlaceholder(const dmlc::Row& row) { + return TBlob(nullptr, mshadow::Shape1(0), cpu::kDevMask, mshadow::kInt64); + } + + inline TBlob AsScalarLabelBlob(const dmlc::Row& row) { + const real_t* ptr = row.label; + return TBlob((real_t*) ptr, mshadow::Shape1(1), cpu::kDevMask); // NOLINT(*) + } + + LibSVMIterParam param_; + // output instance + DataInst out_; + // internal instance counter + unsigned inst_counter_{0}; + // at end + bool end_{false}; + // label parser + size_t label_ptr_{0}, label_size_{0}; + size_t data_ptr_{0}, data_size_{0}; + std::unique_ptr > label_parser_; + std::unique_ptr > data_parser_; +}; + + +DMLC_REGISTER_PARAMETER(LibSVMIterParam); + +MXNET_REGISTER_IO_ITER(LibSVMIter) +.describe(R"code(Returns the LibSVM file iterator. This iterator is experimental and +should be used with care. + +The input data is similar to libsvm file format, except that the indices are expected to be +zero-based instead of one-based. Details of the libsvm format are available at +`https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/` + +In this function, the `data_shape` parameter is used to set the shape of each line of the data. +The dimension of both `data_shape` and `label_shape` are expected to be 1. + +When `label_libsvm` is set to ``NULL``, both data and label are read from the same file specified +by `data_libsvm`. 
Otherwise, data is read from `data_libsvm` and label from `label_libsvm`; +in this case, any label found in `data_libsvm` is ignored. + +The `LibSVMIter` only supports the `round_batch` parameter set to ``True`` for now. So, if `batch_size` +is 3 and there are 4 total rows in the libsvm file, 2 more examples +are consumed in the first round. If the `reset` function is called after the first round, +the call is ignored and the remaining examples are returned in the second round. + +If ``data_libsvm = 'data/'`` is set, then all the files in this directory will be read. + +Examples:: + + // Contents of libsvm file ``data.t``. + 1.0 0:0.5 2:1.2 + -2.0 + -3.0 0:0.6 1:2.4 2:1.2 + 4 2:-1.2 + + // Creates a `LibSVMIter` with `batch_size`=3. + LibSVMIter = mx.io.LibSVMIter(data_libsvm = 'data.t', data_shape = (3,), + batch_size = 3) + + // The first batch (data and label) + [[ 0.5 0. 1.2 ] + [ 0. 0. 0. ] + [ 0.6 2.4 1.2 ]] + + [ 1. -2. -3.] + + // The second batch (data and label) + [[ 0. 0. -1.2 ] + [ 0.5 0. 1.2 ] + [ 0. 0. 0. ]] + + [ 4. 1. -2.] + + // Contents of libsvm file ``label.t`` + 1.0 + -2.0 0:0.125 + -3.0 2:1.2 + 4 1:1.0 2:-1.2 + + // Creates a `LibSVMIter` with the specified label file + LibSVMIter = mx.io.LibSVMIter(data_libsvm = 'data.t', data_shape = (3,), + label_libsvm = 'label.t', label_shape = (3,), batch_size = 3) + + // Two batches of data read from the above iterator are as follows (data and label): + // The first batch + [[ 0.5 0. 1.2 ] + [ 0. 0. 0. ] + [ 0.6 2.4 1.2 ]] + + [[ 0. 0. 0. ] + [ 0.125 0. 0. ] + [ 0. 0. 1.2 ]] + + // The second batch + [[ 0. 0. -1.2 ] + [ 0.5 0. 1.2 ] + [ 0. 0. 0. ]] + + [[ 0. 1. -1.2 ] + [ 0. 0. 0. ] + [ 0.125 0. 0. ]] + +)code" ADD_FILELINE) +.add_arguments(LibSVMIterParam::__FIELDS__()) +.add_arguments(BatchParam::__FIELDS__()) +.add_arguments(PrefetcherParam::__FIELDS__()) +.set_body([]() { + return new SparsePrefetcherIter( + new SparseBatchLoader( + new LibSVMIter())); + }); + +} // namespace io +} // namespace mxnet diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 9050ef2d1b38..3eb85b12c077 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -28,8 +28,7 @@ namespace io { class PrefetcherIter : public IIterator { public: explicit PrefetcherIter(IIterator* base) - : loader_(base), out_(nullptr) { - } + : loader_(base), out_(nullptr) {} ~PrefetcherIter() { while (recycle_queue_.size() != 0) { @@ -38,21 +37,24 @@ class PrefetcherIter : public IIterator { delete batch; } delete out_; - iter_.Destroy(); + iter.Destroy(); } - virtual void Init(const std::vector >& kwargs) { + void InitParams(const std::vector >& kwargs) { std::vector > kwargs_left; // init image rec param kwargs_left = param_.InitAllowUnknown(kwargs); - // use the kwarg to init batch loader - loader_->Init(kwargs); // maximum prefetch threaded iter internal size const int kMaxPrefetchBuffer = 16; // init thread iter - iter_.set_max_capacity(kMaxPrefetchBuffer); + iter.set_max_capacity(kMaxPrefetchBuffer); + } - iter_.Init([this](DataBatch **dptr) { + virtual void Init(const std::vector >& kwargs) { + InitParams(kwargs); + // use the kwarg to init batch loader + loader_->Init(kwargs); + iter.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); if (*dptr == nullptr) { @@ -91,7 +93,7 @@ class PrefetcherIter : public IIterator { } virtual void BeforeFirst(void) { - iter_.BeforeFirst(); + iter.BeforeFirst(); } virtual bool Next(void) { @@ -106,9 +108,9 @@ class PrefetcherIter : public IIterator {
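Note: the zero-based libsvm records in the `LibSVMIter` examples above map directly onto the (values, indices, indptr) triplet that the iterator emits for CSR data. A minimal NumPy sketch of that mapping, for illustration only (this is not the MXNet parser; the sample lines are taken from ``data.t`` above)::

    import numpy as np

    def libsvm_to_csr(lines, num_features):
        """Parse zero-based libsvm lines into labels plus a CSR triplet."""
        labels, values, indices, indptr = [], [], [], [0]
        for line in lines:
            tokens = line.split()
            labels.append(float(tokens[0]))
            for kv in tokens[1:]:
                col, val = kv.split(':')
                indices.append(int(col))      # zero-based column id
                values.append(float(val))
            indptr.append(len(values))        # row boundary
        return (np.array(labels), np.array(values),
                np.array(indices, dtype=np.int64), np.array(indptr, dtype=np.int64))

    labels, values, indices, indptr = libsvm_to_csr(
        ["1.0 0:0.5 2:1.2", "-2.0", "-3.0 0:0.6 1:2.4 2:1.2", "4 2:-1.2"], 3)
    # indptr == [0, 2, 2, 5, 6]; row i occupies values[indptr[i]:indptr[i+1]]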
arr.WaitToWrite(); } recycle_queue_.pop(); - iter_.Recycle(&old_batch); + iter.Recycle(&old_batch); } - return iter_.Next(&out_); + return iter.Next(&out_); } virtual const DataBatch &Value(void) const { return *out_; @@ -117,16 +119,16 @@ class PrefetcherIter : public IIterator { protected: /*! \brief prefetcher parameters */ PrefetcherParam param_; - /*! \brief internal batch loader */ - std::unique_ptr > loader_; + /*! \brief backend thread */ + dmlc::ThreadedIter iter; private: + /*! \brief internal batch loader */ + std::unique_ptr > loader_; /*! \brief output data */ DataBatch *out_; /*! \brief queue to be recycled */ std::queue recycle_queue_; - /*! \brief backend thread */ - dmlc::ThreadedIter iter_; }; } // namespace io } // namespace mxnet diff --git a/src/io/iter_sparse.h b/src/io/iter_sparse.h new file mode 100644 index 000000000000..24e3d81ee553 --- /dev/null +++ b/src/io/iter_sparse.h @@ -0,0 +1,27 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file iter_sparse.h + * \brief mxnet sparse data iterator + */ +#ifndef MXNET_IO_ITER_SPARSE_H_ +#define MXNET_IO_ITER_SPARSE_H_ + +#include +#include + +namespace mxnet { +/*! + * \brief iterator type + * \param DType data type + */ +template +class SparseIIterator : public IIterator { + public: + /*! \brief storage type of the data or label */ + virtual const NDArrayStorageType GetStorageType(bool is_data) const = 0; + /*! \brief shape of the data or label */ + virtual const TShape GetShape(bool is_data) const = 0; +}; // class SparseIIterator + +} // namespace mxnet +#endif // MXNET_IO_ITER_SPARSE_H_ diff --git a/src/io/iter_sparse_batchloader.h b/src/io/iter_sparse_batchloader.h new file mode 100644 index 000000000000..a89f21acb2a4 --- /dev/null +++ b/src/io/iter_sparse_batchloader.h @@ -0,0 +1,184 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file iter_sparse_batchloader.h + * \brief define a batch adapter to create sparse tblob batch + */ +#ifndef MXNET_IO_ITER_SPARSE_BATCHLOADER_H_ +#define MXNET_IO_ITER_SPARSE_BATCHLOADER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "./inst_vector.h" +#include "./image_iter_common.h" +#include "./iter_batchloader.h" +#include "./iter_sparse.h" + +namespace mxnet { +namespace io { + +/*! 
\brief create a batch iterator from single instance iterator */ +class SparseBatchLoader : public BatchLoader, public SparseIIterator { + public: + explicit SparseBatchLoader(SparseIIterator *base): + BatchLoader(base), sparse_base_(base) { + } + + virtual ~SparseBatchLoader(void) {} + + inline void Init(const std::vector >& kwargs) { + BatchLoader::Init(kwargs); + data_stype_ = sparse_base_->GetStorageType(true); + label_stype_ = sparse_base_->GetStorageType(false); + if (param_.round_batch == 0) { + LOG(FATAL) << "sparse batch loader doesn't support round_batch == false yet"; + } + } + + virtual void BeforeFirst(void) { + BatchLoader::BeforeFirst(); + } + + virtual bool Next(void) { + out_.num_batch_padd = 0; + out_.batch_size = param_.batch_size; + this->head_ = 0; + // if overflown from previous round, directly return false, until before first is called + if (num_overflow_ != 0) return false; + index_t top = 0; + inst_cache_.clear(); + while (sparse_base_->Next()) { + inst_cache_.emplace_back(sparse_base_->Value()); + if (inst_cache_.size() >= param_.batch_size) break; + } + // no more data instance + if (inst_cache_.size() == 0) { + return false; + } + if (inst_cache_.size() < param_.batch_size) { + CHECK_GT(param_.round_batch, 0); + num_overflow_ = 0; + sparse_base_->BeforeFirst(); + for (; inst_cache_.size() < param_.batch_size; ++num_overflow_) { + CHECK(sparse_base_->Next()) << "number of input must be bigger than batch size"; + inst_cache_.emplace_back(sparse_base_->Value()); + } + } + out_.num_batch_padd = num_overflow_; + CHECK_EQ(inst_cache_.size(), param_.batch_size); + this->InitDataFromBatch(); + for (size_t j = 0; j < inst_cache_.size(); j++) { + const auto& d = inst_cache_[j]; + out_.inst_index[top] = d.index; + // TODO(haibin) double check the type? + int64_t unit_size = 0; + for (size_t i = 0; i < d.data.size(); ++i) { + // indptr tensor + if (IsIndPtr(i)) { + auto indptr = data_[i].get(); + if (j == 0) indptr[0] = 0; + indptr[j + 1] = indptr[j] + unit_size; + offsets_[i] = j; + } else { + // indices and values tensor + unit_size = d.data[i].shape_.Size(); + MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, { + const auto begin = offsets_[i]; + const auto end = offsets_[i] + unit_size; + mshadow::Copy(data_[i].get().Slice(begin, end), + d.data[i].get_with_shape(mshadow::Shape1(unit_size))); + }); + offsets_[i] += unit_size; + } + } + } + return true; + } + + virtual const TBlobBatch &Value(void) const { + return BatchLoader::Value(); + } + + virtual const NDArrayStorageType GetStorageType(bool is_data) const { + return sparse_base_->GetStorageType(is_data); + } + + virtual const TShape GetShape(bool is_data) const { + TShape inst_shape = sparse_base_->GetShape(is_data); + std::vector shape_vec; + shape_vec.push_back(param_.batch_size); + for (index_t dim = 0; dim < inst_shape.ndim(); ++dim) { + shape_vec.push_back(inst_shape[dim]); + } + return TShape(shape_vec.begin(), shape_vec.end()); + } + + private: + /*! \brief base sparse iterator */ + SparseIIterator *sparse_base_; + /*! \brief data instances */ + std::vector inst_cache_; + /*! \brief data storage type */ + NDArrayStorageType data_stype_; + /*! \brief data label type */ + NDArrayStorageType label_stype_; + /*! 
\brief tensor offset for slicing */ + std::vector offsets_; + + // check whether ith position is the indptr tensor for a CSR tensor + inline bool IsIndPtr(size_t i) { + auto data_num_aux = num_aux_data(data_stype_); + auto label_num_aux = num_aux_data(label_stype_); + auto label_indptr_offset = data_num_aux + 1 + label_num_aux; + // data indptr + if (i == data_num_aux && data_stype_ == kCSRStorage) { + return true; + } + // label indptr + if (i == label_indptr_offset && label_stype_ == kCSRStorage && data_stype_ == kCSRStorage) { + return true; + } + return false; + } + + // initialize the data holder by using from the batch + inline void InitDataFromBatch() { + CHECK(data_stype_ == kCSRStorage || label_stype_ == kCSRStorage); + CHECK_GT(inst_cache_.size(), 0); + out_.data.clear(); + offsets_.clear(); + + size_t total_size = inst_cache_[0].data.size(); + data_.resize(total_size); + offsets_.resize(total_size, 0); + std::vector vec_sizes(total_size, 0); + // accumulate the memory required for a batch + for (size_t i = 0; i < total_size; ++i) { + size_t size = 0; + // vec_size for indptr + if (IsIndPtr(i)) { + size = param_.batch_size + 1; + } else { + for (const auto &d : inst_cache_) size += d.data[i].shape_.Size(); + } + vec_sizes[i] = size; + } + + CHECK_EQ(vec_sizes[0], vec_sizes[1]); + for (size_t i = 0; i < total_size; ++i) { + int src_type_flag = inst_cache_[0].data[i].type_flag_; + // init object attributes + TShape dst_shape(mshadow::Shape1(vec_sizes[i])); + data_[i].resize(mshadow::Shape1(vec_sizes[i]), src_type_flag); + CHECK(data_[i].dptr_ != nullptr); + out_.data.push_back(TBlob(data_[i].dptr_, dst_shape, cpu::kDevMask, src_type_flag)); + } + } +}; // class BatchLoader +} // namespace io +} // namespace mxnet +#endif // MXNET_IO_ITER_SPARSE_BATCHLOADER_H_ diff --git a/src/io/iter_sparse_prefetcher.h b/src/io/iter_sparse_prefetcher.h new file mode 100644 index 000000000000..79b4fa8e2c6c --- /dev/null +++ b/src/io/iter_sparse_prefetcher.h @@ -0,0 +1,135 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file iter_sparse_prefetcher.h + * \brief define a prefetcher using threaditer to keep k batch fetched + */ +#ifndef MXNET_IO_ITER_SPARSE_PREFETCHER_H_ +#define MXNET_IO_ITER_SPARSE_PREFETCHER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./inst_vector.h" +#include "./image_iter_common.h" +#include "./iter_prefetcher.h" +#include "./iter_sparse.h" + +namespace mxnet { +namespace io { +// iterator on sparse data +class SparsePrefetcherIter : public PrefetcherIter { + public: + explicit SparsePrefetcherIter(SparseIIterator* base) + : PrefetcherIter(base), sparse_loader_(base) {} + + ~SparsePrefetcherIter() {} + + virtual void Init(const std::vector >& kwargs) { + PrefetcherIter::InitParams(kwargs); + // use the kwarg to init batch loader + sparse_loader_->Init(kwargs); + iter.Init([this](DataBatch **dptr) { + if (!sparse_loader_->Next()) return false; + const TBlobBatch& batch = sparse_loader_->Value(); + if (*dptr == nullptr) { + // allocate databatch + *dptr = new DataBatch(); + (*dptr)->num_batch_padd = batch.num_batch_padd; + // (*dptr)->data.at(0) => data + // (*dptr)->data.at(1) => label + (*dptr)->data.resize(2); + (*dptr)->index.resize(batch.batch_size); + size_t data_iter = 0; + for (size_t i = 0; i < (*dptr)->data.size(); ++i) { + bool is_data = i == 0; + auto stype = this->GetStorageType(is_data); + auto dtype = param_.dtype ? 
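Note: `SparseBatchLoader::Next` above stitches per-instance CSR rows into one batch-level CSR blob by accumulating an indptr across instances and copying each instance's indices and values at a running offset. A rough NumPy equivalent of that concatenation (sketch only; the helper name is made up)::

    import numpy as np

    def concat_csr_rows(rows):
        """rows: list of (values, indices) array pairs, one pair per instance."""
        indptr = np.zeros(len(rows) + 1, dtype=np.int64)
        for j, (vals, _) in enumerate(rows):
            indptr[j + 1] = indptr[j] + len(vals)   # same role as the indptr update in Next()
        values = np.concatenate([v for v, _ in rows])
        indices = np.concatenate([i for _, i in rows])
        return values, indices, indptr

    values, indices, indptr = concat_csr_rows(
        [(np.array([0.5, 1.2]), np.array([0, 2])),
         (np.array([]), np.array([], dtype=np.int64))])
    # indptr == [0, 2, 2]: the second (empty) instance contributes no values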
param_.dtype.value() : batch.data[data_iter].type_flag_; + if (stype == kDefaultStorage) { + (*dptr)->data.at(i) = NDArray(batch.data[data_iter].shape_, + Context::CPU(), false, dtype); + } else { + (*dptr)->data.at(i) = NDArray(stype, this->GetShape(is_data), + Context::CPU(), false, dtype); + } + data_iter += num_aux_data(stype) + 1; + } + } + // copy data over + size_t data_iter = 0; + for (size_t i = 0; i < (*dptr)->data.size(); ++i) { + auto& nd = ((*dptr)->data)[i]; + auto stype = nd.storage_type(); + auto& data_i = ((*dptr)->data)[i]; + if (stype == kDefaultStorage) { + CopyFromTo(data_i.data(), batch.data[data_iter]); + } else if (stype == kCSRStorage) { + auto& values = batch.data[data_iter]; + auto& indices = batch.data[data_iter + 1]; + auto& indptr = batch.data[data_iter + 2]; + // allocate memory + CHECK_EQ(indices.shape_.Size(), values.shape_.Size()); + nd.CheckAndAllocAuxData(csr::kIdx, indices.shape_); + nd.CheckAndAllocData(values.shape_); + nd.CheckAndAllocAuxData(csr::kIndPtr, indptr.shape_); + // copy values, indices and indptr + CopyFromTo(data_i.data(), values); + CopyFromTo(data_i.aux_data(csr::kIdx), indices); + CopyFromTo(data_i.aux_data(csr::kIndPtr), indptr); + } else { + LOG(FATAL) << "Storage type not implemented: " << stype; + } + data_iter += num_aux_data(stype) + 1; + (*dptr)->num_batch_padd = batch.num_batch_padd; + } + if (batch.inst_index) { + std::copy(batch.inst_index, + batch.inst_index + batch.batch_size, + (*dptr)->index.begin()); + } + return true; + }, + [this]() { sparse_loader_->BeforeFirst(); }); + } + + virtual void BeforeFirst(void) { + PrefetcherIter::BeforeFirst(); + } + + virtual bool Next(void) { + return PrefetcherIter::Next(); + } + virtual const DataBatch &Value(void) const { + return PrefetcherIter::Value(); + } + + virtual const NDArrayStorageType GetStorageType(bool is_data) const { + return sparse_loader_->GetStorageType(is_data); + } + + virtual const TShape GetShape(bool is_data) const { + return sparse_loader_->GetShape(is_data); + } + + private: + /*! \brief internal sparse batch loader */ + SparseIIterator* sparse_loader_; + + inline void CopyFromTo(TBlob dst, const TBlob src) { + MSHADOW_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Copy(dst.FlatTo1D(), src.FlatTo1D()); + }); + } +}; +} // namespace io +} // namespace mxnet +#endif // MXNET_IO_ITER_SPARSE_PREFETCHER_H_ diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 1197d4ef3edb..a8a5cf3b5268 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -3,13 +3,16 @@ */ #ifndef MXNET_KVSTORE_COMM_H_ #define MXNET_KVSTORE_COMM_H_ +#include #include #include #include #include #include #include +#include #include "mxnet/ndarray.h" +#include "../common/utils.h" namespace mxnet { namespace kvstore { /** @@ -29,9 +32,10 @@ class Comm { } virtual ~Comm() { } /** - * \brief init key with the data shape + * \brief init key with the data shape and storage shape */ - virtual void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) = 0; + virtual void Init(int key, const NDArrayStorageType stype, + const TShape& shape, int dtype = mshadow::kFloat32) = 0; /** * \brief returns src[0] + .. 
+ src[src.size()-1] */ @@ -64,43 +68,84 @@ class CommCPU : public Comm { CommCPU() { nthread_reduction_ = dmlc::GetEnv("MXNET_KVSTORE_REDUCTION_NTHREADS", 4); bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); + // TODO(junwu) delete the following data member, now for benchmark only + is_serial_push_ = dmlc::GetEnv("MXNET_KVSTORE_SERIAL_PUSH", 0); } virtual ~CommCPU() { } - void Init(int key, const TShape& shape, int type = mshadow::kFloat32) override { - merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type); + void Init(int key, const NDArrayStorageType stype, const TShape& shape, + int type = mshadow::kFloat32) override { + if (stype == kDefaultStorage) { + merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type); + } else { + merge_buf_[key].merged = NDArray(stype, shape, pinned_ctx_, true, type); + } } const NDArray& Reduce(int key, const std::vector& src, int priority) override { + auto& buf = merge_buf_[key]; // avoid extra copy for single device, but it may bring problems for // abnormal usage of kvstore if (src.size() == 1) { - return src[0]; + if (src[0].storage_type() == buf.merged.storage_type()) { + return src[0]; + } else { + CopyFromTo(src[0], &buf.merged, priority); + return buf.merged; + } } - std::vector const_vars(src.size() - 1); - std::vector reduce(src.size()); - auto& buf = merge_buf_[key]; - CopyFromTo(src[0], &buf.merged, priority); - reduce[0] = buf.merged; - if (buf.copy_buf.empty()) { - buf.copy_buf.resize(src.size()-1); - for (size_t j = 0; j < src.size() - 1; ++j) { - buf.copy_buf[j] = NDArray( - src[0].shape(), pinned_ctx_, false, src[0].dtype()); + if (buf.merged.storage_type() == kDefaultStorage) { + std::vector const_vars(src.size() - 1); + std::vector reduce(src.size()); + CopyFromTo(src[0], &buf.merged, priority); + reduce[0] = buf.merged; + + if (buf.copy_buf.empty()) { + buf.copy_buf.resize(src.size()-1); + for (size_t j = 0; j < src.size() - 1; ++j) { + // allocate NDArray basd on storage type + buf.copy_buf[j] = NDArray( + src[0].shape(), pinned_ctx_, false, src[0].dtype()); + } } - } - for (size_t i = 1; i < src.size(); ++i) { - CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); - reduce[i] = buf.copy_buf[i-1]; - const_vars[i-1] = reduce[i].var(); - } + for (size_t i = 1; i < src.size(); ++i) { + CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); + reduce[i] = buf.copy_buf[i-1]; + const_vars[i-1] = reduce[i].var(); + } + + Engine::Get()->PushSync([reduce, this](RunContext rctx) { + ReduceSumCPU(reduce); + }, Context::CPU(), const_vars, {reduce[0].var()}, + FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + + } else { + // buf.merged is a sparse ndarray. + std::vector const_vars(src.size()); + std::vector reduce(src.size()); - Engine::Get()->PushSync([reduce, this](RunContext rctx) { - ReduceSumCPU(reduce); - }, Context::CPU(), const_vars, {reduce[0].var()}, - FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + if (buf.copy_buf.empty()) { + buf.copy_buf.resize(src.size()); + for (size_t j = 0; j < src.size(); ++j) { + buf.copy_buf[j] = NDArray( + src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype()); + } + } + for (size_t i = 0; i < src.size(); ++i) { + CopyFromTo(src[i], &(buf.copy_buf[i]), priority); + reduce[i] = buf.copy_buf[i]; + const_vars[i] = reduce[i].var(); + } + auto result = buf.merged; + Engine::Get()->PushSync([reduce, result, this](RunContext rctx) { + NDArray out = result; + is_serial_push_? 
+ ReduceSumCPUExSerial(reduce, &out) : ReduceSumCPUExParallel(reduce, &out); + }, Context::CPU(), const_vars, {result.var()}, + FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + } return buf.merged; } @@ -133,6 +178,188 @@ class CommCPU : public Comm { }); } + // serial implementation of reduce sum for row sparse NDArray. + // TODO(haibin) use openmp kernel to parallelize the summation + inline void ReduceSumCPUExSerial(const std::vector &in, NDArray *out) { + using namespace rowsparse; + using namespace mshadow; + auto stype = out->storage_type(); + CHECK_EQ(stype, kRowSparseStorage) << "Unexpected storage type " << stype; + size_t total_num_rows = 0; + size_t num_in = in.size(); + // skip the ones with empty indices and values + std::vector skip(num_in, false); + // the values tensor of the inputs + MSHADOW_TYPE_SWITCH(out->dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, { + std::vector> in_vals(num_in); + std::vector> in_indices(num_in); + // offset to the values tensor of all inputs + std::vector offsets(num_in, 0); + std::vector num_rows(num_in, 0); + for (size_t i = 0; i < num_in; i++) { + if (!in[i].storage_initialized()) { + skip[i] = true; + continue; + } + auto size = in[i].aux_shape(kIdx).Size(); + num_rows[i] = size; + total_num_rows += size; + in_vals[i] = in[i].data().FlatTo2D(); + in_indices[i] = in[i].aux_data(kIdx).FlatTo1D(); + } + std::vector indices; + indices.reserve(total_num_rows); + // gather indices from all inputs + for (size_t i = 0; i < num_in; i++) { + for (size_t j = 0; j < num_rows[i]; j++) { + indices.emplace_back(in_indices[i][j]); + } + } + CHECK_EQ(indices.size(), total_num_rows); + // dedup indices + std::sort(indices.begin(), indices.end()); + indices.resize(std::unique(indices.begin(), indices.end()) - indices.begin()); + // the one left are unique non-zero rows + size_t nnr = indices.size(); + // allocate memory for output + out->CheckAndAlloc({Shape1(nnr)}); + auto idx_data = out->aux_data(kIdx).FlatTo1D(); + auto val_data = out->data().FlatTo2D(); + + for (size_t i = 0; i < nnr; i++) { + // copy indices back + idx_data[i] = indices[i]; + bool zeros = true; + for (size_t j = 0; j < num_in; j++) { + if (skip[j]) continue; + size_t offset = offsets[j]; + if (offset < num_rows[j]) { + if (indices[i] == in_indices[j][offset]) { + if (zeros) { + Copy(val_data[i], in_vals[j][offset], nullptr); + zeros = false; + } else { + val_data[i] += in_vals[j][offset]; + } + offsets[j] += 1; + } + } + } + } + }); + }); + } + + template + void ReduceSumCPUExImpl(const std::vector& nds, + const std::vector& uniq_row_idx, + NDArray* out) { +#pragma omp parallel num_threads(nthread_reduction_) + { + const size_t nnr = uniq_row_idx.size(); + const int num_threads = omp_get_num_threads(); + size_t row_block_len = (nnr + num_threads - 1) / num_threads; + const size_t row_block_start = omp_get_thread_num() * row_block_len; + if (row_block_start < nnr) { + const size_t row_block_end = std::min(row_block_start+row_block_len, nnr); + + auto out_values = out->data().FlatTo2D(); + auto out_indices = out->aux_data(rowsparse::kIdx).FlatTo1D(); + for (size_t i = row_block_start; i < row_block_end; ++i) { + out_indices[i] = uniq_row_idx[i]; + } + for (const auto& nd : nds) { + if (nd.storage_initialized()) { + const auto nd_indices = nd.aux_data(rowsparse::kIdx).FlatTo1D(); + const auto nd_values = nd.data().FlatTo2D(); + const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size(); + const IType* nd_indices_start = &nd_indices[0]; + 
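Note: `ReduceSumCPUExSerial` above gathers the row indices of all inputs, sorts and deduplicates them to obtain the unique non-zero rows, and then accumulates each input's rows into the matching output slot. A compact NumPy sketch of the same reduction over row-sparse (indices, values) pairs (illustration only, not the MXNet kernel)::

    import numpy as np

    def reduce_row_sparse(parts, row_len):
        """parts: list of (row_idx, values) with values of shape (len(row_idx), row_len)."""
        all_idx = np.concatenate([idx for idx, _ in parts if len(idx)])
        uniq = np.unique(all_idx)                       # sorted unique non-zero rows
        out = np.zeros((len(uniq), row_len))
        pos = {row: i for i, row in enumerate(uniq)}    # row id -> output slot
        for idx, vals in parts:
            for row, val_row in zip(idx, vals):
                out[pos[row]] += val_row
        return uniq, out

    uniq, out = reduce_row_sparse(
        [(np.array([0, 3]), np.ones((2, 4))), (np.array([3, 7]), np.ones((2, 4)))], 4)
    # uniq == [0, 3, 7]; the slot for row 3 sums contributions from both inputs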
const IType* nd_indices_end = nd_indices_start + nd_num_rows; + const IType* row_idx_ptr = std::lower_bound(nd_indices_start, nd_indices_end, + out_indices[row_block_start]); + // skip this nd if all of its row indices are smaller than out_indices[row_block_start] + // or current row block is not covered by [*row_idx_ptr, nd_indices_end). + if (nd_indices_end == row_idx_ptr || *row_idx_ptr > out_indices[row_block_end-1]) { + continue; + } + for (size_t irow = row_block_start; + irow < row_block_end && row_idx_ptr != nd_indices_end;) { + if (out_indices[irow] == *row_idx_ptr) { + auto out_value_cur_row = out_values[irow]; + const auto offset = row_idx_ptr - nd_indices_start; + auto nd_value_cur_row = nd_values[offset]; + for (size_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) { + out_value_cur_row[j] += nd_value_cur_row[j]; + } + ++irow; + ++row_idx_ptr; + } else if (out_indices[irow] < *row_idx_ptr) { + ++irow; + } else { + ++row_idx_ptr; + } + } + } + } + } + } + } + + /*! + * \brief Given a vector of ndarrays, generate a index vector containing + * all the unique row indices of the ndarrays. + */ + template + void GetUniqueRspRowIdx(const std::vector& nds, + std::vector* uniq_row_idx) { + using namespace rowsparse; + size_t total_num_rows = 0; + for (const auto& nd : nds) { + CHECK_EQ(nd.storage_type(), kRowSparseStorage); + if (nd.storage_initialized()) { + total_num_rows += nd.aux_shape(kIdx).Size(); + } + } + + uniq_row_idx->resize(total_num_rows); + int nthreads = omp_get_max_threads(); + int offset = 0; + for (const auto& nd : nds) { + if (nd.storage_initialized()) { + const IType* nd_row_idx = nd.aux_data(kIdx).dptr(); + const int num_rows = nd.aux_shape(kIdx).Size(); +#pragma omp parallel for num_threads(nthreads) + for (int i = 0; i < num_rows; ++i) { + (*uniq_row_idx)[offset+i] = nd_row_idx[i]; + } + offset += num_rows; + } + } + + common::ParallelSort(uniq_row_idx->begin(), uniq_row_idx->end(), nthreads); + auto it = std::unique(uniq_row_idx->begin(), uniq_row_idx->end()); + uniq_row_idx->resize(it - uniq_row_idx->begin()); + } + + void ReduceSumCPUExParallel(const std::vector& nds, NDArray* out) { + if (nds.empty()) return; + using namespace rowsparse; + CHECK_EQ(out->storage_type(), kRowSparseStorage) + << "Expected row sparse storage type (" + << out->storage_type() << " given)"; + + MSHADOW_TYPE_SWITCH(out->dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, { + std::vector uniq_row_idx; + GetUniqueRspRowIdx(nds, &uniq_row_idx); + out->CheckAndAlloc({mshadow::Shape1(uniq_row_idx.size())}); + out->data().FlatTo2D() = static_cast(0); + ReduceSumCPUExImpl(nds, uniq_row_idx, out); + }); + }); + } + template inline static void ReduceSumCPU( const std::vector &dptr, size_t offset, index_t size) { @@ -198,6 +425,7 @@ class CommCPU : public Comm { std::unordered_map merge_buf_; size_t bigarray_bound_; int nthread_reduction_; + bool is_serial_push_; }; /** @@ -216,8 +444,13 @@ class CommDevice : public Comm { virtual ~CommDevice() { } - void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) override { - sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype)); + void Init(int key, const NDArrayStorageType stype, const TShape& shape, + int dtype = mshadow::kFloat32) override { + if (stype == kDefaultStorage) { + sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype)); + } else { + LOG(FATAL) << "storage type " << stype << " not implemented for device yet"; + } } const NDArray& Reduce(int key, const std::vector& src, diff --git 
a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index be5662e8a6db..78d4958096cc 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -7,7 +7,6 @@ #include #include #include "./kvstore_local.h" -// #include "./kvstore_device.h" #if MXNET_USE_DIST_KVSTORE #include "./kvstore_dist.h" #endif // MXNET_USE_DIST_KVSTORE diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 5f5a0cc67a64..59d9158012ef 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -11,6 +11,7 @@ #include "mxnet/engine.h" #include "ps/ps.h" #include "./kvstore_dist_server.h" +#include "../operator/tensor/init_op.h" #if MKL_EXPERIMENTAL == 1 #include #include "../operator/mkl/mkl_memory-inl.h" @@ -42,6 +43,7 @@ class KVStoreDist : public KVStoreLocal { } } bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); + row_sparse_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); } virtual ~KVStoreDist() { @@ -63,7 +65,7 @@ class KVStoreDist : public KVStoreLocal { const std::vector& values) override { CheckUnique(keys); for (size_t i = 0; i < keys.size(); ++i) { - comm_->Init(keys[i], values[i].shape(), values[i].dtype()); + comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype()); } if (get_rank() == 0) { Push_(keys, values, 0, false); @@ -97,36 +99,51 @@ class KVStoreDist : public KVStoreLocal { // use the same array for merging to guarantee that pull always happens // after the previous push on this key auto& recv_buf = comm_buf_[key]; + const auto storage_type = grouped_vals[i][0]->storage_type(); if (recv_buf.is_none()) { // it may happen for the first time a no-rank-0 worker pull the weight. - recv_buf = NDArray( - grouped_vals[i][0]->shape(), pinned_ctx_, false, grouped_vals[i][0]->dtype()); + if (storage_type == kDefaultStorage) { + recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_, + false, grouped_vals[i][0]->dtype()); + } else { + recv_buf = NDArray(storage_type, grouped_vals[i][0]->shape(), + pinned_ctx_, true, grouped_vals[i][0]->dtype()); + // initialize the buffer with sufficient memory + op::FillDnsZerosRspImpl(nullptr, &recv_buf); + } } + if (storage_type == kDefaultStorage) { #if MKL_EXPERIMENTAL == 1 - mkl_set_tblob_eager_mode(recv_buf.data()); + mkl_set_tblob_eager_mode(recv_buf.data()); #endif - real_t* data = static_cast(recv_buf.data().dptr_); - size_t size = recv_buf.shape().Size(); - - auto pull_from_servers = [this, key, data, size]( - RunContext rctx, Engine::CallbackOnComplete cb) { - // convert to ps keys - PSKV& pskv = EncodeKey(key, size); - - // issue pull, false means no delete - auto vals = new ps::SArray(data, size, false); - CHECK_NOTNULL(ps_worker_)->ZPull( - pskv.keys, vals, &pskv.lens, 0, [vals, cb](){ delete vals; cb(); }); - }; - - CHECK_NOTNULL(Engine::Get())->PushAsync( - pull_from_servers, - pinned_ctx_, - {}, - {recv_buf.var()}, - FnProperty::kNormal, - priority, - PROFILER_MESSAGE("KVStoreDistPull")); + real_t* data = static_cast(recv_buf.data().dptr_); + size_t size = recv_buf.shape().Size(); + auto pull_from_servers = [this, key, data, size]( + RunContext rctx, Engine::CallbackOnComplete cb) { + // convert to ps keys + PSKV& pskv = EncodeKey(key, size); + + // issue pull, false means no delete + auto vals = new ps::SArray(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPull( + pskv.keys, vals, &pskv.lens, kDefaultPushPull, [vals, cb](){ delete vals; cb(); }); + }; + + CHECK_NOTNULL(Engine::Get())->PushAsync( + pull_from_servers, + pinned_ctx_, 
+ {}, + {recv_buf.var()}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistDefaultPull")); + } else if (storage_type == kRowSparseStorage) { + recv_buf.WaitToRead(); + grouped_vals[i][0]->WaitToRead(); + PullRowSparse(key, &recv_buf, grouped_vals[i][0]->aux_ndarray(rowsparse::kIdx), priority); + } else { + LOG(FATAL) << "unknown storage type " << storage_type; + } comm_->Broadcast(key, recv_buf, grouped_vals[i], priority); } @@ -204,41 +221,128 @@ class KVStoreDist : public KVStoreLocal { NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0]; auto& send_buf = comm_buf_[key]; + const auto storage_type = merged.storage_type(); if (merged.ctx().dev_mask() == cpu::kDevMask) { send_buf = merged; // avoid memory copy } else { if (send_buf.is_none()) { - send_buf = NDArray(merged.shape(), pinned_ctx_, false, merged.dtype()); + if (storage_type == kDefaultStorage) { + send_buf = NDArray(merged.shape(), pinned_ctx_, false, merged.dtype()); + } else { + send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype()); + // initialize the buffer with sufficient memory + op::FillDnsZerosRspImpl(nullptr, &send_buf); + } } CopyFromTo(merged, &send_buf); } // push to servers - send_buf.WaitToRead(); - size_t size = send_buf.shape().Size(); + if (storage_type == kDefaultStorage) { + send_buf.WaitToRead(); + size_t size = send_buf.shape().Size(); +#if MKL_EXPERIMENTAL == 1 + mkl_set_tblob_eager_mode(send_buf.data()); +#endif + real_t* data = static_cast(send_buf.data().dptr_); + auto push_to_servers = + [this, key, data, size](RunContext rctx, Engine::CallbackOnComplete cb) { + // convert to ps keys + PSKV& pskv = EncodeKey(key, size); + // do push. false means no delete + ps::SArray vals(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPush( + pskv.keys, vals, pskv.lens, 0, [cb]() { cb(); }); + }; + Engine::Get()->PushAsync( + push_to_servers, + pinned_ctx_, + {send_buf.var()}, + {}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistDefaultPush")); + } else if (storage_type == kRowSparseStorage) { + PushRowSparse(key, send_buf, priority); + } else { + LOG(FATAL) << "unknown storage type"; + } + } + } + + // pull row sparse weight into `recv_buf` based on indices given by `indices` + void PullRowSparse(int key, NDArray *recv_buf, const NDArray indices, int priority) { + using namespace rowsparse; + auto pull_from_servers = [this, key, recv_buf, &indices] + (RunContext rctx, Engine::CallbackOnComplete cb) { + // reading aux_shape & aux_data should be inside the engine + size_t num_rows = indices.shape().Size(); + recv_buf->CheckAndAlloc({mshadow::Shape1(num_rows)}); +#if MKL_EXPERIMENTAL == 1 + mkl_set_tblob_eager_mode(recv_buf->data()); +#endif + real_t* data = static_cast(recv_buf->data().dptr_); + const auto offsets = indices.data().dptr(); + const auto unit_len = recv_buf->shape().ProdShape(1, recv_buf->shape().ndim()); + size_t size = num_rows * unit_len; + // convert to ps keys in row sparse format + PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, unit_len); + if (this->row_sparse_verbose_) { + LOG(INFO) << "pull lens: " << pskv.lens << " keys: " << pskv.keys + << " size: " << size; + } + auto vals = new ps::SArray(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull, + [vals, cb]() { delete vals; cb(); }); + }; + CHECK_NOTNULL(Engine::Get())->PushAsync( + pull_from_servers, + pinned_ctx_, + {indices.var()}, + {recv_buf->var()}, + FnProperty::kNormal, + priority, + 
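Note: `PullRowSparse` above asks the servers only for the rows named in `indices` and writes those rows, together with the indices themselves, into the row-sparse receive buffer. Ignoring the ps-lite transport, the data flow is roughly the following (hypothetical names, sketch only)::

    import numpy as np

    def pull_rows(server_weight, row_ids):
        """server_weight: full dense weight held by the server; row_ids: rows this worker needs."""
        unit_len = server_weight.shape[1]
        vals = server_weight[row_ids]                   # num_rows x unit_len values come back
        # the receive buffer becomes a row-sparse pair: (indices, values)
        return np.asarray(row_ids, dtype=np.int64), vals.reshape(len(row_ids), unit_len)

    idx, vals = pull_rows(np.arange(20.0).reshape(5, 4), [1, 3])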
PROFILER_MESSAGE("KVStoreDistRowSparsePull")); + recv_buf->WaitToRead(); + // copy indices pulled + auto recv_buf_idx = recv_buf->aux_ndarray(kIdx); + CopyFromTo(indices, &recv_buf_idx); + } + + // push row sparse gradient + void PushRowSparse(int key, const NDArray &send_buf, int priority) { + using namespace rowsparse; + auto push_to_servers = [this, key, &send_buf] + (RunContext rctx, Engine::CallbackOnComplete cb) { #if MKL_EXPERIMENTAL == 1 mkl_set_tblob_eager_mode(send_buf.data()); #endif real_t* data = static_cast(send_buf.data().dptr_); - auto push_to_servers = - [this, key, data, size](RunContext rctx, Engine::CallbackOnComplete cb) { - // convert to ps keys - PSKV& pskv = EncodeKey(key, size); - - // do push. false means no delete - ps::SArray vals(data, size, false); - CHECK_NOTNULL(ps_worker_)->ZPush( - pskv.keys, vals, pskv.lens, 0, [cb]() { cb(); }); - }; - Engine::Get()->PushAsync( - push_to_servers, - pinned_ctx_, - {send_buf.var()}, - {}, - FnProperty::kNormal, - priority, - PROFILER_MESSAGE("KVStoreDistPush")); - } + if (!send_buf.storage_initialized()) return; + size_t num_rows = send_buf.aux_shape(kIdx).Size(); + const auto offsets = send_buf.aux_data(kIdx).dptr(); + const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim()); + const auto size = num_rows * unit_len; + + // convert to ps keys in row sparse format + PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, unit_len); + if (this->row_sparse_verbose_) { + LOG(INFO) << "push lens: " << pskv.lens << " keys: " << pskv.keys + << " size: " << size; + } + ps::SArray vals(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, kRowSparsePushPull, [cb]() { + cb(); + }); + }; + Engine::Get()->PushAsync( + push_to_servers, + pinned_ctx_, + {send_buf.var()}, + {}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistRowSparsePush")); } /** @@ -266,7 +370,7 @@ class KVStoreDist : public KVStoreLocal { std::unordered_map ps_kv_; /** - * \brief serizelize EncodeKey + * \brief serizelize EncodeRowSparseKey and EncodeKey */ std::mutex mu_; @@ -313,6 +417,37 @@ class KVStoreDist : public KVStoreLocal { return pskv; } + inline PSKV& EncodeRowSparseKey(int key, size_t size, int64_t num_rows, + const int64_t *offsets, size_t unit_len) { + mu_.lock(); + PSKV& pskv = ps_kv_[key]; + mu_.unlock(); + pskv.keys.clear(); + pskv.lens.clear(); + // TODO(haibin) cache this information + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + int num_servers = krs.size(); + CHECK_GT(num_servers, 0); + + if (size >= bigarray_bound_ && row_sparse_verbose_) { + LOG(INFO) << "WARNING: big row_sparse weight array sharding is not implemented"; + } + // send it to a single random picked server + int server = (key * 9973) % num_servers; + ps::Key master_key = krs[server].begin() + key; + pskv.keys.push_back(master_key); + pskv.lens.push_back(0); + for (int64_t i = 0; i < num_rows; i++) { + ps::Key ps_key = krs[server].begin() + key + offsets[i]; + CHECK_LT(ps_key, krs[server].end()); + pskv.keys.push_back(ps_key); + pskv.lens.push_back(unit_len); + } + pskv.size = size; + return pskv; + } + + /** * \brief for worker to push and pull data */ @@ -327,6 +462,7 @@ class KVStoreDist : public KVStoreLocal { size_t bigarray_bound_; /// \brief send & recver buffer std::unordered_map comm_buf_; + bool row_sparse_verbose_; }; } // namespace kvstore diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 02d4a38c2b10..59d2cb705654 100644 --- 
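Note: `EncodeRowSparseKey` above turns one row-sparse array into a ps-lite key list: a zero-length master key for the array itself, followed by one key per non-zero row, each carrying `unit_len` values, all routed to a single server derived from the integer key. A sketch of that arithmetic (server key ranges are simplified to plain start offsets; not the actual ps-lite types)::

    def encode_row_sparse_key(key, row_ids, unit_len, server_range_begins):
        """server_range_begins: assumed start of each server's ps key range."""
        server = (key * 9973) % len(server_range_begins)
        begin = server_range_begins[server]
        keys = [begin + key]                  # master key for the whole array
        lens = [0]                            # the master key carries no values
        for row in row_ids:
            keys.append(begin + key + row)    # one ps key per non-zero row
            lens.append(unit_len)             # each key carries one row of values
        return keys, lens

    keys, lens = encode_row_sparse_key(3, [0, 4, 9], unit_len=128,
                                       server_range_begins=[0, 1 << 20])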
a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -19,6 +19,8 @@ namespace mxnet { namespace kvstore { +static const int kRowSparsePushPull = 1; +static const int kDefaultPushPull = 0; static const int kStopServer = -1; static const int kSyncMode = -2; @@ -92,7 +94,7 @@ class KVStoreDistServer { static_cast(ps_server_)->set_request_handle( std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2)); ps_server_->set_request_handle( - std::bind(&KVStoreDistServer::DataHandle, this, _1, _2, _3)); + std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3)); sync_mode_ = false; } @@ -133,9 +135,162 @@ class KVStoreDistServer { app->Response(recved); } - void DataHandle(const ps::KVMeta& req_meta, - const ps::KVPairs& req_data, - ps::KVServer* server) { + void DataHandleEx(const ps::KVMeta& req_meta, + const ps::KVPairs& req_data, + ps::KVServer* server) { + if (req_meta.cmd == kRowSparsePushPull) { + DataHandleRowSparse(req_meta, req_data, server); + } else { + DataHandleDefault(req_meta, req_data, server); + } + return; + } + + inline void MergeUpdates(const NDArray& recved, int key, + std::unordered_set *change_set) { + auto& merged = merge_buf_[key]; + if (merged.is_none()) { + merged = NDArray(recved.shape(), Context()); + } + if (change_set->find(key) == change_set->end()) { + CopyFromTo(recved, &merged, 0); + } else { + // TODO(haibin) handle row sparse gradient NDArray with `ReduceSumCPUExParallel` + merged += recved; + } + change_set->insert(key); + } + + void DataHandleRowSparse(const ps::KVMeta& req_meta, + const ps::KVPairs& req_data, + ps::KVServer* server) { + int master_key = DecodeKey(req_data.keys[0]); + auto num_rows = req_data.keys.size() - 1; + if (req_meta.push) { + CHECK_EQ(req_data.lens[0], 0); + CHECK_GT(req_data.lens.size(), 0); + auto unit_len = req_data.lens[1]; + CHECK_GT(unit_len, 0); + real_t* data = req_data.vals.data(); + auto& stored = store_[master_key]; + if (stored.is_none()) { + // LOG(INFO) << "initial push: " << master_key << " size = " << num_rows * unit_len; + // initialization + size_t ds[] = {num_rows, (size_t) unit_len}; + TShape dshape(ds, ds + 2); + CHECK_EQ(req_data.vals.size(), num_rows * unit_len); + TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*) + NDArray recved = NDArray(recv_blob, 0); + stored = NDArray(dshape, Context()); + CopyFromTo(recved, &stored, 0); + stored.WaitToRead(); + server->Response(req_meta); + return; + } + // synced push + if (sync_mode_) { + // LOG(INFO) << "sync push: " << master_key; + size_t offset = 0; + auto& stored = store_[master_key]; + // merge updates + auto& request_buf = request_buf_[master_key]; + for (size_t i = 1; i <= num_rows; i++) { + // TODO(haibin) decode once and cache result + int key = DecodeKey(req_data.keys[i]); + auto len = req_data.lens[i]; + size_t ds[] = {(size_t)len}; + TShape dshape(ds, ds + 1); + TBlob recv_blob(data, // NOLINT(*) + dshape, cpu::kDevMask); + NDArray recved = NDArray(recv_blob, 0); + MergeUpdates(recved, key, &request_buf.change_set); + offset += len; + } + // perform updates + request_buf.requests.push_back(req_meta); + if (request_buf.requests.size() == (size_t) ps::NumWorkers()) { + // let the main thread to execute updater_, which is necessary for python + for (auto key : request_buf.change_set) { + // slice a row + auto row_id = key - master_key; + NDArray slice = stored.At(row_id); + NDArray update = merge_buf_[key]; + if (updater_) { + exec_.Exec([this, key, &update, &slice](){ + CHECK(updater_); + updater_(key, update, 
&slice); + }); + } else { + // if no updater, just copy + CopyFromTo(update, &slice); + } + slice.WaitToRead(); + } + request_buf.change_set.clear(); + // LOG(INFO) << "RESPONSE SYNC to " << request_buf.requests.size() << " clients"; + for (const auto& req : request_buf.requests) { + server->Response(req); + } + request_buf.requests.clear(); + } else { + for (size_t i = 1; i <= num_rows; i++) { + int key = DecodeKey(req_data.keys[i]); + merge_buf_[key].WaitToRead(); + } + } + } else { + // async push + auto& stored = store_[master_key]; + for (size_t i = 1; i <= num_rows; i++) { + int key = DecodeKey(req_data.keys[i]); + auto row_id = key - master_key; + auto len = req_data.lens[i]; + size_t ds[] = {(size_t)len}; + TShape dshape(ds, ds + 1); + TBlob recv_blob(data, // NOLINT(*) + dshape, cpu::kDevMask); + NDArray recved = NDArray(recv_blob, 0); + NDArray slice = stored.At(row_id); + exec_.Exec([this, key, &recved, &slice](){ + CHECK(updater_); + updater_(key, recved, &slice); + }); + } + server->Response(req_meta); + stored.WaitToRead(); + } + } else { + // pull + ps::KVPairs response; + auto& stored = store_[master_key]; + CHECK(!stored.is_none()) << "init " << master_key << " first"; + auto shape = stored.shape(); + auto unit_len = shape.ProdShape(1, shape.ndim()); + const float* data = stored.data().dptr(); + auto len = unit_len * num_rows; + // LOG(INFO) << "received pull: " << len; + // concat response values + response.vals.resize(len); + for (size_t i = 1; i <= num_rows; i++) { + int key = DecodeKey(req_data.keys[i]); + const auto src = data + key * unit_len; + auto begin = (i - 1) * unit_len; + auto end = i * unit_len; + response.vals.segment(begin, end).CopyFrom(src, unit_len); + } + // setup response + response.keys = req_data.keys; + std::vector lens(req_data.keys.size(), unit_len); + lens[0] = 0; + response.lens.CopyFrom(lens.begin(), lens.end()); + server->Response(req_meta, response); + } + } + + void DataHandleDefault(const ps::KVMeta& req_meta, + const ps::KVPairs &req_data, + ps::KVServer* server) { + CHECK_EQ(req_meta.cmd, kDefaultPushPull); // do some check CHECK_EQ(req_data.keys.size(), (size_t)1); if (req_meta.push) { @@ -164,37 +319,29 @@ class KVStoreDistServer { } else if (sync_mode_) { // synced push auto& merged = merge_buf_[key]; - if (merged.array.is_none()) { - merged.array = NDArray(dshape, Context()); - } - - if (merged.request.size() == 0) { - CopyFromTo(recved, &merged.array, 0); - } else { - merged.array += recved; - } - - merged.request.push_back(req_meta); - - if (merged.request.size() == (size_t)ps::NumWorkers()) { - // let the main thread to execute updater_, which is necessary for - // python + auto& request_buf = request_buf_[key]; + MergeUpdates(recved, key, &request_buf.change_set); + request_buf.requests.push_back(req_meta); + if (request_buf.requests.size() == (size_t) ps::NumWorkers()) { + CHECK_EQ(request_buf.change_set.size(), 1); + // let the main thread to execute updater_, which is necessary for python if (updater_) { exec_.Exec([this, key, &merged, &stored](){ CHECK(updater_); - updater_(key, merged.array, &stored); + updater_(key, merged, &stored); }); } else { // if no updater, just copy - CopyFromTo(merged.array, &stored); + CopyFromTo(merged, &stored); } - for (const auto& req : merged.request) { + request_buf.change_set.clear(); + for (const auto& req : request_buf.requests) { server->Response(req); } - merged.request.clear(); + request_buf.requests.clear(); stored.WaitToRead(); } else { - merged.array.WaitToRead(); + 
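Note: on the server side, `DataHandleRowSparse` above recovers each row id as `key - master_key`, merges the per-row gradients pushed by all workers, and then applies the updater only to those rows of the stored weight. A stripped-down sketch of that sync-push path (pure illustration; `updater` stands in for the registered optimizer)::

    import numpy as np

    def sync_push_rows(stored, master_key, pushes, updater):
        """pushes: one dict per worker of {ps_key: gradient_row}; stored: dense weight."""
        merged = {}
        for worker_push in pushes:
            for ps_key, grad_row in worker_push.items():
                merged[ps_key] = merged.get(ps_key, 0) + grad_row   # MergeUpdates
        for ps_key, grad_row in merged.items():
            row_id = ps_key - master_key                            # recover the row index
            stored[row_id] = updater(stored[row_id], grad_row)      # touch only pushed rows
        return stored

    w = np.zeros((5, 2))
    w = sync_push_rows(w, master_key=100,
                       pushes=[{101: np.ones(2)}, {101: np.ones(2), 103: np.ones(2)}],
                       updater=lambda weight_row, grad: weight_row - 0.1 * grad)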
merged.WaitToRead(); } } else { // async push @@ -209,7 +356,7 @@ class KVStoreDistServer { // pull ps::KVPairs response; CHECK(!stored.is_none()) << "init " << key << " first"; - int len = stored.shape()[0]; + auto len = stored.shape().Size(); response.keys = req_data.keys; response.lens = {len}; // TODO(mli) try to remove this CopyFrom @@ -232,11 +379,14 @@ class KVStoreDistServer { std::unordered_map store_; - struct MergeBuf { - std::vector request; - NDArray array; + struct RequestBuf { + std::vector requests; + std::unordered_set change_set; }; - std::unordered_map merge_buf_; + + std::unordered_map merge_buf_; + std::unordered_map request_buf_; + Executor exec_; diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index caa57a20d46e..e159dd42e596 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include "./comm.h" @@ -43,10 +44,24 @@ class KVStoreLocal : public KVStore { CHECK(local_.find(keys[i]) == local_.end()) << "duplicate init of key " << keys[i]; local_[keys[i]] = values[i].Copy(pinned_ctx_); - comm_->Init(keys[i], values[i].shape(), values[i].dtype()); + comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype()); } } + void Init(const std::vector& str_keys, + const std::vector& values) override { + std::vector keys(str_keys.size()); + for (size_t i = 0; i < str_keys.size(); ++i) { + auto &str_key = str_keys[i]; + CHECK(str_key_dict_.find(str_key) == str_key_dict_.end()) + << "duplicate init of key " << str_key; + auto key = next_str_key_++; + str_key_dict_[str_key] = key; + keys[i] = key; + } + Init(keys, values); + } + void Push(const std::vector& keys, const std::vector& values, int priority) override { @@ -67,7 +82,11 @@ class KVStoreLocal : public KVStore { } updater_(key, merged, &local); } else { - local = merged; + if (merged.storage_type() != local.storage_type()) { + local = merged.Copy(local.ctx()); + } else { + local = merged; + } } } } @@ -87,6 +106,22 @@ class KVStoreLocal : public KVStore { } } + void Push(const std::vector& str_keys, + const std::vector& values, + int priority) override { + std::vector keys(str_keys.size()); + LookupKeys(str_keys, &keys); + Push(keys, values, priority); + } + + void Pull(const std::vector& str_keys, + const std::vector& values, + int priority) override { + std::vector keys(str_keys.size()); + LookupKeys(str_keys, &keys); + Pull(keys, values, priority); + } + protected: /** * \brief group values on keys @@ -118,12 +153,27 @@ class KVStoreLocal : public KVStore { } } } + + void LookupKeys(const std::vector& str_keys, + std::vector *keys) { + for (size_t i = 0; i < str_keys.size(); ++i) { + auto &str_key = str_keys[i]; + CHECK(str_key_dict_.find(str_key) != str_key_dict_.end()) + << "key " << str_key << " doesn't exist. 
Did you init?"; + keys->at(i) = str_key_dict_[str_key]; + } + } + /// reducer and broadcaster Comm* comm_; /// pinned context Context pinned_ctx_; /// \brief buffer for storing local values std::unordered_map local_; + /// key mapping for string -> integer + std::unordered_map str_key_dict_; + /// the next available integer for string->int key mapping + int next_str_key_ = 0; }; } // namespace kvstore } // namespace mxnet diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 6f1795d6f368..c894f27c25b7 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -12,6 +12,9 @@ #include #include #include "./ndarray_function.h" +#include "../common/utils.h" +#include "../operator/tensor/matrix_op-inl.h" +#include "../operator/tensor/init_op.h" #include "./autograd.h" #if MXNET_USE_OPENCV @@ -26,6 +29,8 @@ namespace mxnet { NDArray NDArray::Reshape(const TShape &shape) const { using namespace autograd; + CHECK(storage_type() == kDefaultStorage) << "Reshape for storage type " << + storage_type() << " is not implemented yet"; if (AutogradRuntime::Get()->IsTraining()) { CHECK_GE(shape_.Size(), shape.Size()) << "NDArray.Reshape: target shape must have must have the same size as " @@ -56,40 +61,101 @@ NDArray NDArray::Reshape(const TShape &shape) const { } } - NDArray NDArray::Slice(index_t begin, index_t end) const { using namespace autograd; - NDArray ret = *this; + using namespace mshadow; CHECK(!is_none()) << "NDArray is not initialized"; CHECK_GE(shape_[0], end) << "Slice end index out of range"; - size_t length = shape_.ProdShape(1, shape_.ndim()); - MSHADOW_TYPE_SWITCH(ret.dtype(), DType, { - ret.byte_offset_ += begin * length * sizeof(DType); - }); - ret.shape_[0] = end - begin; - if (AutogradRuntime::Get()->IsTraining()) { - // fake a slice_axis op - ret.entry_.clear(); - const nnvm::Op* op = nnvm::Op::Get("slice_axis"); - nnvm::NodeAttrs attrs; - attrs.op = op; - attrs.dict.insert({"axis", "0"}); - attrs.dict.insert({"begin", std::to_string(begin)}); - attrs.dict.insert({"end", std::to_string(end)}); - op->attr_parser(&attrs); - std::vector inputs, outputs; - inputs.emplace_back(*this); - outputs.emplace_back(std::move(ret)); - AutogradRuntime::Get()->RecordImperativeFCompute( - op, attrs, &inputs, &outputs); - return outputs[0]; - } else { + CHECK_NE(storage_type(), kUndefinedStorage); + if (storage_type() == kDefaultStorage) { + NDArray ret = *this; + auto stype = storage_type(); + size_t length = shape_.ProdShape(1, shape_.ndim()); + MSHADOW_TYPE_SWITCH(ret.dtype(), DType, { + ret.byte_offset_ += begin * length * sizeof(DType); + }); + ret.shape_[0] = end - begin; + if (AutogradRuntime::Get()->IsTraining()) { + // fake a slice_axis op + ret.entry_.clear(); + const nnvm::Op* op = nnvm::Op::Get("slice_axis"); + nnvm::NodeAttrs attrs; + attrs.op = op; + attrs.dict.insert({"axis", "0"}); + attrs.dict.insert({"begin", std::to_string(begin)}); + attrs.dict.insert({"end", std::to_string(end)}); + op->attr_parser(&attrs); + std::vector inputs, outputs; + inputs.emplace_back(*this); + outputs.emplace_back(std::move(ret)); + AutogradRuntime::Get()->RecordImperativeFCompute( + op, attrs, &inputs, &outputs); + return outputs[0]; + } else { + return ret; + } + } else if (storage_type() == kCSRStorage) { + // TODO(haibin) support auto_grad + TShape sliced_shape(Shape2(end-begin, shape()[1])); + using namespace csr; + NDArray ret(storage_type(), TShape(Shape2(end-begin, shape()[1])), + ctx(), true, dtype_, ptr_->aux_types, + {TShape(Shape1(0)), TShape(Shape1(0))}); + NDArray src = 
*this; + // destination NDArray shares the same variable + ret.ptr_->var = var(); + + Engine::Get()->PushSync([src, ret, begin, end](RunContext ctx) { + NDArray dst = ret; + // create a new chunk for dst NDArray + NDArray::Chunk chunk = *src.ptr_; + // void indptr storage handle + chunk.aux_handles[kIndPtr] = Storage::Handle(); + // shape for indptr is end - begin + 1 + chunk.CheckAndAllocAuxData(kIndPtr, Shape1(end - begin + 1)); + if (src.ctx().dev_mask() == cpu::kDevMask) { + MSHADOW_INT_TYPE_SWITCH(src.aux_type(kIndPtr), IType, { + MSHADOW_TYPE_SWITCH(src.dtype(), DType, { + // create new indptr + const IType* src_indptr = src.aux_data(kIndPtr).dptr(); + IType* dst_indptr = static_cast (chunk.aux_handles[kIndPtr].dptr); + op::SliceCsrIndPtrImpl(begin, end, ctx, src_indptr, dst_indptr); + // advance idx and values pointers (CPU implementation) + // TODO(haibin) refactor for GPU implementation later + IType offset = src_indptr[begin]; + IType* idx = static_cast(chunk.aux_handles[kIdx].dptr); + DType* values = static_cast(chunk.shandle.dptr); + chunk.aux_handles[kIdx].dptr = idx + offset; + chunk.shandle.dptr = values + offset; + // update storage shape and aux shape (CPU implementation) + auto nnz = dst_indptr[end - begin]; + chunk.aux_shapes[kIdx] = Shape1(nnz); + chunk.storage_shape = Shape1(nnz); + chunk.static_data = true; + chunk.skip_delete_var = true; + // update dst chunk + *dst.ptr_ = chunk; + }); + }); + } else { +#if MXNET_USE_CUDA + LOG(FATAL) << "SliceEx CSR not implemented yet"; +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } + }, ctx(), {}, {var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); return ret; + } else { + LOG(FATAL) << "Slice not yet implemented for storage " << storage_type(); } + return NDArray(); } - NDArray NDArray::At(index_t idx) const { + CHECK(storage_type() == kDefaultStorage) << "Storage type " + << storage_type() << " doesn't support At()"; NDArray ret = this->Slice(idx, idx+1); if (shape_.ndim() > 1) { return ret.Reshape(TShape(shape_.data()+1, shape_.data()+shape_.ndim())); @@ -212,11 +278,11 @@ void BinaryOp(const NDArray &lhs, // redirect everything to mshadow operations switch (lhs.ctx().dev_mask()) { case cpu::kDevMask: { - Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }, lhs.ctx(), const_vars, {ret.var()}, - FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { + TBlob tmp = ret.data(); + ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + }, lhs.ctx(), const_vars, {ret.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); break; } #if MXNET_USE_CUDA @@ -242,6 +308,7 @@ void SetValueOp(const real_t &rhs, NDArray *out) { switch (ret.ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { + CHECK(ret.storage_type() == kDefaultStorage); TBlob tmp = ret.data(); ndarray::Eval(rhs, &tmp, ctx); }, ret.ctx(), {}, {ret.var()}, @@ -313,6 +380,112 @@ void ScalarOp(const NDArray &lhs, } } +size_t num_aux_data(NDArrayStorageType stype) { + size_t num = 0; + switch (stype) { + case kDefaultStorage: num = 0; break; + case kCSRStorage: num = 2; break; + case kRowSparseStorage: num = 1; break; + default: LOG(FATAL) << "Unknown storage type" << stype; break; + } + return num; +} + +// Make a copy of a CSR NDArray +template +inline void CopyFromToCsrImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + 
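Note: the CSR branch of `NDArray::Slice` above does not copy values; it rebases the indptr by `src_indptr[begin]` and advances the indices/values pointers by that same offset. The index arithmetic, in NumPy form (sketch only, using the 4x3 ``data.t`` example from the LibSVMIter docstring)::

    import numpy as np

    def slice_csr(values, indices, indptr, begin, end):
        """Slice rows [begin, end) of a CSR matrix without touching the other rows."""
        offset = indptr[begin]
        new_indptr = indptr[begin:end + 1] - offset     # rebased row pointers, length end-begin+1
        nnz = new_indptr[-1]
        return values[offset:offset + nnz], indices[offset:offset + nnz], new_indptr

    vals, idx, ptr = slice_csr(np.array([0.5, 1.2, 0.6, 2.4, 1.2, -1.2]),
                               np.array([0, 2, 0, 1, 2, 2]),
                               np.array([0, 2, 2, 5, 6]), 1, 3)
    # rows 1..2 of the example matrix: ptr == [0, 0, 3], nnz == 3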
CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + // if source storage is not initialized, fill destination with zeros + auto s = ctx.get_stream(); + if (!from.storage_initialized()) { + op::FillZerosCsrImpl(s, to); + return; + } + // Allocate storage + to->CheckAndAllocAuxData(csr::kIndPtr, from.aux_shape(csr::kIndPtr)); + to->CheckAndAllocAuxData(csr::kIdx, from.aux_shape(csr::kIdx)); + to->CheckAndAllocData(from.aux_shape(csr::kIdx)); + TBlob val = to->data(); + TBlob indptr = to->aux_data(csr::kIndPtr); + TBlob idx = to->aux_data(csr::kIdx); + ndarray::Copy(from.data(), &val, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(csr::kIndPtr), &indptr, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(csr::kIdx), &idx, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of a row-sparse NDArray +template +inline void CopyFromToRspImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + // if source is zeros, fill destination with zeros, too + auto s = ctx.get_stream(); + if (!from.storage_initialized()) { + op::FillZerosRspImpl(s, to); + return; + } + auto aux_shape = from.aux_shape(rowsparse::kIdx); + to->CheckAndAlloc({aux_shape}); + TBlob val = to->data(); + TBlob idx = to->aux_data(rowsparse::kIdx); + ndarray::Copy(from.data(), &val, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(rowsparse::kIdx), &idx, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of a dense NDArray +template +inline void CopyFromToDnsImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + TBlob tmp = to->data(); + ndarray::Copy(from.data(), &tmp, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of an NDArray based on storage type +template +void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace std; + using namespace mshadow; + // if storage type doesn't match, cast the storage first + auto from_stype = from.storage_type(); + auto to_stype = to->storage_type(); + NDArray casted_nd; + if (from_stype != to_stype) { + TShape shape = from.shape(); + auto from_ctx = from.ctx(); + auto s = ctx.get_stream(); + // TODO(haibin) inplace conversion + if (to_stype == kDefaultStorage) { + casted_nd = NDArray(shape, from_ctx); + } else { + casted_nd = NDArray(to_stype, shape, from_ctx); + } + common::CastStorageDispatch(s, from, casted_nd); + } else { + casted_nd = from; + } + if (to_stype == kDefaultStorage) { + CopyFromToDnsImpl(casted_nd, to, ctx); + } else if (to_stype == kRowSparseStorage) { + CopyFromToRspImpl(casted_nd, to, ctx); + } else if (to_stype == kCSRStorage) { + CopyFromToCsrImpl(casted_nd, to, ctx); + } else { + LOG(FATAL) << "unknown storage type" << to_stype; + } + if (is_same::value || is_same::value) { + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + } +} + void CopyFromTo(const NDArray &from, NDArray *to, int priority) { if (from.var() == to->var()) { // skip to copy to itself @@ -327,44 +500,33 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { NDArray ret = *to; int a = from.ctx().dev_mask(); int b = to->ctx().dev_mask(); - std::vector const_vars; if (from.var() != ret.var()) const_vars.push_back(from.var()); if (a == cpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - 
TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kNormal, priority, PROFILER_MESSAGE("CopyCPU2CPU")); } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, ret.ctx(), const_vars, {ret.var()}, FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("CopyCPU2GPU")); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2CPU")); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, from.dtype() != ret.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2GPU")); @@ -638,34 +800,76 @@ NDArray &NDArray::operator/=(const real_t &src) { /* magic number for ndarray version 1, with int64_t TShape */ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8; +/* magic number for ndarray version 2, with storage type */ +static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9; + void NDArray::Save(dmlc::Stream *strm) const { - strm->Write(NDARRAY_V1_MAGIC); + // write magic number to mark this version + // for storage type + strm->Write(NDARRAY_V2_MAGIC); + + // save storage type + int32_t stype = storage_type(); + strm->Write(&stype, sizeof(stype)); + + const int32_t nad = num_aux_data(storage_type()); + // save storage shape if ndarray is sparse + if (nad > 0) { + storage_shape().Save(strm); + } + + // save shape shape_.Save(strm); if (is_none()) return; + // save context Context ctx = this->ctx(); ctx.Save(strm); TBlob save_data; - NDArray temp; + NDArray nd_cpu; // a copy of *this on cpu if (ctx.dev_mask() != cpu::kDevMask) { - temp = this->Copy(Context::CPU()); - temp.WaitToRead(); - save_data = temp.data(); + nd_cpu = this->Copy(Context::CPU()); + nd_cpu.WaitToRead(); + save_data = nd_cpu.data(); } else { this->WaitToRead(); save_data = this->data(); + nd_cpu = *this; } + // save type flag int32_t type_flag = save_data.type_flag_; strm->Write(&type_flag, sizeof(type_flag)); + + // save aux_types and aux_shapes + if (nad > 0) { + for (int i = 0; i < nad; ++i) { + int32_t aux_type_flag = aux_type(i); + strm->Write(&aux_type_flag, sizeof(aux_type_flag)); + aux_shape(i).Save(strm); + } + } + + // save data CHECK(save_data.CheckContiguous()); size_t type_size = mshadow::mshadow_sizeof(type_flag); - strm->Write(save_data.dptr_, type_size * shape_.Size()); + // save data could be values of sparse tensors + // must use save_data.shape_ instead of this->shape_ + strm->Write(save_data.dptr_, type_size * save_data.shape_.Size()); + + // save aux data + if (nad > 0) { + for (int i = 0; i < nad; ++i) { + TBlob 
save_data = nd_cpu.aux_data(i); + // save aux_data + CHECK(save_data.CheckContiguous()); + size_t aux_type_size = mshadow::mshadow_sizeof(aux_type(i)); + strm->Write(save_data.dptr_, aux_type_size * save_data.Size()); + } + } } -bool LegacyTShapeLoad(dmlc::Stream *strm, TShape *shape) { - uint32_t magic; - if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false; +bool LegacyTShapeLoad(dmlc::Stream *strm, TShape *shape, const uint32_t magic) { switch (magic) { case NDARRAY_V1_MAGIC: return shape->Load(strm); @@ -681,10 +885,10 @@ bool LegacyTShapeLoad(dmlc::Stream *strm, TShape *shape) { } } -bool NDArray::Load(dmlc::Stream *strm) { +bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) { // load shape TShape shape; - if (!LegacyTShapeLoad(strm, &shape)) return false; + if (!LegacyTShapeLoad(strm, &shape, magic)) return false; if (shape.ndim() == 0) { *this = NDArray(); return true; } @@ -712,6 +916,88 @@ bool NDArray::Load(dmlc::Stream *strm) { } } +bool NDArray::Load(dmlc::Stream *strm) { + uint32_t magic; + if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false; + if (magic != NDARRAY_V2_MAGIC) { + return LegacyLoad(strm, magic); + } + + // load storage type + int32_t stype; + if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false; + const int32_t nad = num_aux_data(static_cast(stype)); + + // load storage shape + TShape sshape; + if (nad > 0) { + if (!sshape.Load(strm)) return false; + } + + // load shape + TShape shape; + if (!shape.Load(strm)) return false; + if (shape.ndim() == 0) { + *this = NDArray(); return true; + } + + // load context + Context ctx; + if (!ctx.Load(strm)) return false; + + // load type flag + int32_t type_flag; + if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false; + + // load aux_types and aux_shapes + std::vector aux_types; + std::vector aux_shapes; + if (nad > 0) { + aux_types.resize(nad); + aux_shapes.resize(nad); + for (int i = 0; i < nad; ++i) { + // load aux_type(i) + if (strm->Read(&aux_types[i], sizeof(aux_types[i])) != sizeof(aux_types[i])) return false; + // load aux_shapes(i) + if (!aux_shapes[i].Load(strm)) return false; + } + } + + // load data into CPU + NDArray temp; + if (0 == nad) { + temp = NDArray(shape, Context::CPU(), false, type_flag); + } else { + temp = NDArray(static_cast(stype), shape, + Context::CPU(), false, type_flag, + aux_types, aux_shapes, sshape); + } + // load data + TBlob load_data = temp.data(); + size_t type_size = mshadow::mshadow_sizeof(type_flag); + size_t nread = type_size * load_data.Size(); + if (strm->Read(load_data.dptr_, nread) != nread) return false; + + // load aux_data + if (nad > 0) { + for (int i = 0; i < nad; ++i) { + load_data = temp.aux_data(i); + type_size = mshadow::mshadow_sizeof(load_data.type_flag_); + nread = type_size * load_data.Size(); + if (strm->Read(load_data.dptr_, nread) != nread) return false; + } + } + + if (ctx.dev_mask() == cpu::kDevMask) { + *this = std::move(temp); return true; + } else { +#if MXNET_USE_CUDA + *this = temp.Copy(ctx); return true; +#else + *this = std::move(temp); return true; +#endif + } +} const uint64_t kMXAPINDArrayListMagic = 0x112; @@ -744,7 +1030,16 @@ void NDArray::Load(dmlc::Stream* fi, } NDArray NDArray::Copy(Context ctx) const { - NDArray ret(shape(), ctx, true, dtype_); + NDArray ret; + if (kDefaultStorage == storage_type()) { + ret = NDArray(shape(), ctx, true, dtype_); + } else if (kUndefinedStorage != storage_type()) { + ret = NDArray(storage_type(), shape(), ctx, 
true, dtype_, + ptr_->aux_types, ptr_->aux_shapes, storage_shape()); + } else { + LOG(FATAL) << "NDArray::Copy cannot copy undefined storage-type ndarray to ctx.dev_type=" + << ctx.dev_type << ", ctx.dev_id=" << ctx.dev_id; + } CopyFromTo(*this, &ret); return ret; } diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h index 28524b73d0dd..aad80fd4360a 100644 --- a/src/ndarray/ndarray_function-inl.h +++ b/src/ndarray/ndarray_function-inl.h @@ -12,27 +12,28 @@ // macro to help specialize evaluation function #ifndef DECL_TERNARY -#define DECL_TERNARY(XPU, OP, FUN) \ - template<> \ - void Eval(const TBlob &lhs, const TBlob &mhs, \ - const TBlob &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, mhs, rhs, ret, ctx); \ +#define DECL_TERNARY(XPU, OP, FUN) \ + template<> \ + void Eval(const TBlob &lhs, const TBlob &mhs, \ + const TBlob &rhs, TBlob *ret, RunContext ctx) { \ + FUN(lhs, mhs, rhs, ret, ctx); \ } #endif #ifndef DECL_BINARY -#define DECL_BINARY(XPU, OP, FUN) \ - template<> \ +#define DECL_BINARY(XPU, OP, FUN) \ + template<> \ void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ + FUN(lhs, rhs, ret, ctx); \ } #endif #ifndef DECL_SCALAR -#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ - template<> \ - void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ +#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ + template<> \ + void Eval(const TBlob &lhs, const real_t &rhs, \ + TBlob *ret, RunContext ctx) { \ + FUN(lhs, rhs, ret, ctx); \ } #endif @@ -44,10 +45,11 @@ namespace mxnet { namespace ndarray { + // true implementation template -inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalBinary_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(ret->type_flag_, lhs.type_flag_) @@ -61,10 +63,9 @@ inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, }); } - template -inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalOneHot_(const TBlob &index, const TBlob &rhs, + TBlob *ret, RunContext ctx) { LOG(INFO) << "The operator onehot_encode is deprecated; use one_hot instead."; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); @@ -81,8 +82,8 @@ inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, } template -inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); // TODO(eric): support mixed type choose, i.e. int index and float rhs. 
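
For reference, the v2 NDArray save format introduced in ndarray.cc above lays out a sparse array in this order: magic number, storage type, storage shape (sparse only), logical shape, context, value data type flag, per-aux-array type flags and shapes, the value blob, and finally the aux blobs. The standalone sketch below only mirrors that field order with plain std::ostream writes; WriteShape, the dev_type/dev_id pair, and the kCsrStorage enum value are simplified stand-ins for illustration, not the real dmlc::Stream / TShape::Save / Context::Save code.

// Illustrative only: mirrors the field order of the v2 save format for a CSR array.
#include <cstdint>
#include <ostream>
#include <vector>

static void WritePod(std::ostream& os, const void* p, size_t n) {
  os.write(reinterpret_cast<const char*>(p), static_cast<std::streamsize>(n));
}

// Simplified stand-in for TShape::Save: ndim followed by int64 dims.
static void WriteShape(std::ostream& os, const std::vector<int64_t>& shape) {
  uint32_t ndim = static_cast<uint32_t>(shape.size());
  WritePod(os, &ndim, sizeof(ndim));
  WritePod(os, shape.data(), shape.size() * sizeof(int64_t));
}

void SaveCsrV2(std::ostream& os,
               const std::vector<int64_t>& shape,          // logical shape
               const std::vector<int64_t>& storage_shape,  // shape of the value blob
               int32_t dtype_flag,
               const std::vector<int32_t>& aux_type_flags,           // one per aux array
               const std::vector<std::vector<int64_t>>& aux_shapes,  // one per aux array
               const std::vector<char>& values,
               const std::vector<std::vector<char>>& aux_blobs) {    // indptr, indices
  const uint32_t kMagicV2 = 0xF993fac9;   // NDARRAY_V2_MAGIC from the diff
  const int32_t kCsrStorage = 2;          // assumed enum value, for illustration only
  WritePod(os, &kMagicV2, sizeof(kMagicV2));
  WritePod(os, &kCsrStorage, sizeof(kCsrStorage));
  WriteShape(os, storage_shape);          // written only for sparse storage types
  WriteShape(os, shape);
  int32_t dev_type = 1, dev_id = 0;       // stand-in for Context::Save
  WritePod(os, &dev_type, sizeof(dev_type));
  WritePod(os, &dev_id, sizeof(dev_id));
  WritePod(os, &dtype_flag, sizeof(dtype_flag));
  for (size_t i = 0; i < aux_type_flags.size(); ++i) {    // aux types + aux shapes
    WritePod(os, &aux_type_flags[i], sizeof(aux_type_flags[i]));
    WriteShape(os, aux_shapes[i]);
  }
  WritePod(os, values.data(), values.size());             // value blob
  for (const auto& blob : aux_blobs) {                    // aux blobs last
    WritePod(os, blob.data(), blob.size());
  }
}
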
@@ -98,8 +99,8 @@ inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, } template -inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); ret->get(s) @@ -109,8 +110,8 @@ inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob } template -inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, - TBlob *ret, RunContext ctx) { +void EvalScalar_(const TBlob &lhs, const real_t &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(ret->type_flag_, lhs.type_flag_) @@ -130,7 +131,7 @@ inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, template<> void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max, - TBlob *ret, RunContext ctx) { + TBlob *ret, RunContext ctx) { typedef DEVICE xpu; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); @@ -145,12 +146,11 @@ void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max } template<> -void EvalRandom( - const real_t &a, - const real_t &b, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +void EvalRandom(const real_t &a, + const real_t &b, + const Resource &resource, + TBlob *ret, + RunContext ctx) { typedef DEVICE xpu; mshadow::Stream *s = ctx.get_stream(); switch (ret->type_flag_) { @@ -426,6 +426,7 @@ DECL_SCALAR(DEVICE, Plus, EvalScalar_, true) DECL_SCALAR(DEVICE, Minus, EvalScalar_, true) DECL_SCALAR(DEVICE, Mul, EvalScalar_, true) DECL_SCALAR(DEVICE, Div, EvalScalar_, true) + // for reverse seq DECL_SCALAR(DEVICE, Plus, EvalScalar_, false) DECL_SCALAR(DEVICE, Minus, EvalScalar_, false) diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index def38126d08c..3f2000f6ee99 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -53,6 +53,42 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, return true; } +// Only inferring output storage types from input for now +template +inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + auto deduce = [&](std::vector *vec, const char *name, AttrType& result, + bool fallback) { + auto &v = *vec; + for (size_t i = 0; i < vec->size(); ++i) { + if (v[i] == kUndefinedStorage) { + // if input type is unknown, assume it's default storage + CHECK(assign(&v[i], kDefaultStorage)); + } else if (assign(&result, v[i]) == false && fallback) { + result = kDefaultStorage; + } + } + }; + AttrType dattr = kUndefinedStorage; + deduce(in_attrs, "input", dattr, enable_fallback); + if (reverse_infer) { + LOG(FATAL) << "not implemented yet"; + } + auto write = [&](std::vector *vec, const char *name) { + for (size_t i = 0; i < vec->size(); ++i) { + CHECK(assign(&(*vec)[i], dattr)) + << "Incompatible attr in node " << attrs.name << " at " << i << "-th " + << name << ": " << "expected " << dattr << ", got " << (*vec)[i]; + } + }; + if (is_none(dattr)) dattr = kDefaultStorage; + write(out_attrs, "output"); + return true; +} + template inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, @@ -73,6 +109,29 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } +template +inline bool ElemwiseStorageType(const nnvm::NodeAttrs& 
attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; + return ElemwiseStorageAttr( + attrs, in_attrs, out_attrs); +} + +inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), static_cast(2)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(1)) << " in operator " << attrs.name; + auto &in = *in_attrs; + auto &out = *out_attrs; + CHECK_NE(in[1], kUndefinedStorage) << "rhs storage type must be known"; + if (in[0] == kUndefinedStorage) in[0] = in[1]; + if (out[0] == kUndefinedStorage) out[0] = in[1]; + return true; +} + // Transfer gradient and input to FGradient function struct ElemwiseGradUseIn { const char *op_name; @@ -105,6 +164,22 @@ struct ElemwiseGradUseNone { } }; +// TODO(haibin) this is a temporary function for debugging purpose. Remove later. +template +void print_info(const mshadow::Tensor& tensor, const std::string& name) { + std::cout << "Tensor " << name << " with shape ("; + int len = 1; + for (int i = 0; i < dim; i++) { + len *= tensor.shape_[i]; + std::cout << tensor.shape_[i] << ","; + if (i == dim - 1) std::cout << ")"; + } + std::cout << std::endl; + for (int j = 0; j < len; j ++) std::cout << tensor.dptr_[j] << " "; + std::cout << std::endl; +} + + } // namespace op } // namespace mxnet diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 9b5dcfe3d3b1..d4a473c8be0c 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -7,6 +7,7 @@ #ifndef MXNET_OPERATOR_MXNET_OP_H_ #define MXNET_OPERATOR_MXNET_OP_H_ +#include #include #include @@ -22,6 +23,8 @@ const float PI = 3.14159265358979323846; using std::isnan; #endif +template +int get_num_threads(const int N); #ifdef __CUDACC__ #define CUDA_KERNEL_LOOP(i, n) \ @@ -37,8 +40,18 @@ inline int cuda_get_num_blocks(const int N) { using namespace mshadow::cuda; return std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum); } + +template<> +inline int get_num_threads(const int N) { + using namespace mshadow::cuda; + return kBaseThreadNum * cuda_get_num_blocks(N); +} #endif // __CUDACC__ +template<> +inline int get_num_threads(const int N) { + return omp_get_max_threads(); +} /*! \brief operator request type switch */ #define MXNET_ASSIGN_REQ_SWITCH(req, ReqType, ...) \ @@ -198,7 +211,6 @@ __global__ void mxnet_generic_kernel(int N, Args... args) { } } - template struct Kernel { template diff --git a/src/operator/nn/cast_storage-inl.cuh b/src/operator/nn/cast_storage-inl.cuh new file mode 100644 index 000000000000..b99d875eb612 --- /dev/null +++ b/src/operator/nn/cast_storage-inl.cuh @@ -0,0 +1,26 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file cast_storage-inl.cuh + * \brief implementation of cast_storage op on GPU + */ +#ifndef MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_ +#define MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_ + +#include +#include + +namespace mxnet { +namespace op { + +inline void CastStorageDnsRspImpl(mshadow::Stream* s, const TBlob& dns, NDArray* rsp) { + LOG(FATAL) << "CastStorageDnsRspImpl gpu version is not implemented."; +} + +inline void CastStorageDnsCsrImpl(mshadow::Stream* s, const TBlob& dns, NDArray* csr) { + LOG(FATAL) << "CastStorageDnsCsrImpl gpu version is not implemented."; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_ diff --git a/src/operator/nn/cast_storage-inl.h b/src/operator/nn/cast_storage-inl.h new file mode 100644 index 000000000000..1fb32045b9a0 --- /dev/null +++ b/src/operator/nn/cast_storage-inl.h @@ -0,0 +1,335 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file cast_storage-inl.h + * \brief cast_storage implementation for dense and sparse tensors + */ +#ifndef MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_ +#define MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_ + +#include +#include +#include +#include "../mxnet_op.h" +#include "../operator_common.h" +#ifdef __CUDACC__ +#include "./cast_storage-inl.cuh" +#endif // __CUDACC__ + + +namespace mxnet { +namespace op { + +/*! + * \brief Kernel for marking row_idx of a RSP matrix per row + */ +struct MarkRspRowIdx { + // i represents the row index of the matrix data + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, const DType* data, + const index_t num_cols) { + index_t j = 0; + index_t offset = i * num_cols; + for (; j < num_cols; ++j) { + if (data[offset+j] != 0) { + break; + } + } + if (num_cols == j) { + row_idx[i] = 0; // mark as zero for zero row + } else { + row_idx[i] = 1; // mark as one for non-zero row + } + } +}; + +/*! + * \brief + * CPU implementation of casting a dns tensor to rsp type. 
+ */ +inline void CastStorageDnsRspImpl(mshadow::Stream* s, const TBlob& dns, NDArray* rsp) { + CHECK(rsp != nullptr); + CHECK_EQ(rsp->storage_type(), kRowSparseStorage); + CHECK_EQ(dns.shape_, rsp->shape()); + MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type + const index_t num_rows = dns.shape_[0]; + const index_t num_cols = dns.shape_[1]; + rsp->CheckAndAllocAuxData(rowsparse::kIdx, mshadow::Shape1(num_rows)); + TBlob row_idx_blob = rsp->aux_data(rowsparse::kIdx); + RType* row_idx = row_idx_blob.dptr(); + mxnet_op::Kernel::Launch(s, num_rows, row_idx, + dns.dptr(), num_cols); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, num_rows, nnr); + rsp->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + if (0 == nnr) return; + rsp->CheckAndAllocData(mshadow::Shape2(nnr, num_cols)); + mshadow::Tensor dns_data = dns.FlatTo2D(s); + mshadow::Tensor rsp_data = rsp->data().FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < num_rows; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], dns_data[i], s); + ++idx; + } + } + }); + }); +} + +// TODO(haibin) Use memcopy instead will be much faster than assigning each individual element +struct CastStorageRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, const index_t width, const IType* idx, const DType *data, + DType* dns) { + auto rid = idx[i]; + auto dns_offset = rid * width; + auto rsp_offset = i * width; + for (size_t col = 0; col < width; col++) { + dns[dns_offset + col] = data[rsp_offset + col]; + } + } +}; + +/*! + * \brief This function assumes that the meomry for dns has been allocated already + * since the shape is known at binding stage. + */ +template +void CastStorageRspDnsImpl(mshadow::Stream* s, const NDArray& rsp, TBlob* dns) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(rsp.storage_type(), kRowSparseStorage); + MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { + MSHADOW_INT_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { + // assign zeros + mxnet_op::Kernel::Launch(s, dns->Size(), dns->dptr()); + if (rsp.storage_initialized()) { + // copy over row by row + auto in_idx = rsp.aux_data(rowsparse::kIdx).FlatTo1D(s).dptr_; + auto in_data = rsp.data().FlatTo2D(s).dptr_; + auto out_data = dns->FlatTo2D(s).dptr_; + auto num_rows = rsp.aux_shape(rowsparse::kIdx).Size(); + auto rsp_shape = rsp.shape(); + auto width = rsp_shape.ProdShape(1, rsp_shape.ndim()); + mxnet_op::Kernel::Launch(s, num_rows, width, in_idx, + in_data, out_data); + } + }); + }); +} + +/*! + * \brief This is the kernel for initializing the indptr in a csr tensor. + */ +struct FillCsrIndPtr { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param indptr indptr of the csr tensor + * \param dns the dns tensor + * \param num_rows + * \param num_cols + */ + template + MSHADOW_XINLINE static void Map(int i, IType* indptr, const DType* dns, + const int num_rows, const int num_cols) { + indptr[i+1] = 0; + const int offset = i * num_cols; + for (int j = 0; j < num_cols; ++j) { + if (dns[offset+j] != 0) { + ++indptr[i+1]; + } + } + } +}; + +/*! + * \brief This is the kernel for initializing the col_idx and value array + * of the csr tensor + */ +struct FillCsrColIdxAndVals { + /*! 
+ * \brief + * \param i the i-th row of the dns tensor + * \param val value array of the csr + * \param col_idx column idx array of the csr + * \param indptr indptr array of the csr + * \param dns the dns tensor + * \param num_rows number of rows of the dns + * \param num_cols number of columns of the dns + */ + template + MSHADOW_XINLINE static void Map(int i, DType* val, CType* col_idx, + const IType* indptr, const DType* dns, + const int num_rows, const int num_cols) { + const int offset = i * num_cols; + int k = indptr[i]; + for (int j = 0; j < num_cols; ++j) { + if (dns[offset+j] != 0) { + val[k] = dns[offset+j]; + col_idx[k] = j; + ++k; + } + } + } +}; + +/*! + * \brief + * CPU implementation of casting a dns tensor to csr type. + */ +inline void CastStorageDnsCsrImpl(mshadow::Stream* s, const TBlob& dns, NDArray* csr) { + CHECK(csr != nullptr); + CHECK_EQ(csr->storage_type(), kCSRStorage); + CHECK_EQ(dns.shape_.ndim(), 2); + CHECK_EQ(dns.shape_, csr->shape()); + MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col idx type + const index_t num_rows = dns.shape_[0]; + const index_t num_cols = dns.shape_[1]; + csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1)); + IType* indptr = csr->aux_data(csr::kIndPtr).dptr(); + DType* dns_data = dns.dptr(); + mxnet_op::Kernel::Launch(s, num_rows, indptr, + dns_data, num_rows, num_cols); + // single thread to accumulate indptr + // indptr[num_rows] indicates the number of non-zero elements + indptr[0] = 0; + for (index_t i = 0; i < num_rows; ++i) { + indptr[i+1] += indptr[i]; + } + // allocate column idx array and value array + csr->CheckAndAllocAuxData(csr::kIdx, + mshadow::Shape1(static_cast(indptr[num_rows]))); + csr->CheckAndAllocData(mshadow::Shape1(static_cast(indptr[num_rows]))); + // fill col_idx and value arrays of the csr + mxnet_op::Kernel::Launch(s, num_rows, + csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), + indptr, dns_data, num_rows, num_cols); + }); + }); + }); +} + +/*! + * \brief This is the kernel for copying csr.data to its corresponding dns tensor. + */ +struct CopyCsrDataToDns { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param dns_data data blob of the dns tensor + * \param col_idx column idx array of the csr + * \param indptr indptr array of the csr + * \param csr_data data blob of the csr tensor + * \param num_cols number of columns of the dns + */ + template + MSHADOW_XINLINE static void Map(int i, DType* dns_data, const CType* col_idx, + const IType* indptr, const DType* csr_data, + const int num_cols) { + const int offset = i * num_cols; + for (auto j = indptr[i]; j < indptr[i+1]; ++j) { + dns_data[offset+col_idx[j]] = csr_data[j]; + } + } +}; + +/*! + * \brief Casts a csr tensor to dns format. 
+ */ +template +void CastStorageCsrDnsImpl(mshadow::Stream* s, const NDArray& csr, TBlob* dns) { + CHECK(dns != nullptr); + CHECK_EQ(csr.storage_type(), kCSRStorage); + CHECK_EQ(dns->shape_.ndim(), 2); + CHECK_EQ(dns->shape_, csr.shape()); + MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, { // col idx type + const index_t num_rows = dns->shape_[0]; + const index_t num_cols = dns->shape_[1]; + DType* dns_data = dns->dptr(); + mxnet_op::Kernel::Launch(s, dns->shape_.Size(), dns_data); + if (!csr.storage_initialized()) return; + const IType* indptr = csr.aux_data(csr::kIndPtr).dptr(); + const CType* col_idx = csr.aux_data(csr::kIdx).dptr(); + const DType* csr_data = csr.data().dptr(); + mxnet_op::Kernel::Launch(s, num_rows, dns_data, + col_idx, indptr, csr_data, num_cols); + }); + }); + }); +} + +template +void CastStorageComputeImpl(mshadow::Stream* s, + const NDArray& input, + const NDArray& output) { + using namespace mshadow; + using namespace mshadow::expr; + const auto src_stype = input.storage_type(); + const auto dst_stype = output.storage_type(); + if (src_stype == kRowSparseStorage && dst_stype == kDefaultStorage) { + TBlob ret = output.data(); + CastStorageRspDnsImpl(s, input, &ret); + } else if (src_stype == kDefaultStorage && dst_stype == kRowSparseStorage) { + NDArray ret = output; // get rid of the const qualifer + CastStorageDnsRspImpl(s, input.data(), &ret); + } else if (src_stype == kDefaultStorage && dst_stype == kCSRStorage) { + NDArray ret = output; // get rid of the const qualifer + CastStorageDnsCsrImpl(s, input.data(), &ret); + } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) { + TBlob ret = output.data(); + CastStorageCsrDnsImpl(s, input, &ret); + } else { + LOG(FATAL) << "Not implemented"; + } +} + +struct CastStorageParam : public dmlc::Parameter { + int storage_type; + DMLC_DECLARE_PARAMETER(CastStorageParam) { + DMLC_DECLARE_FIELD(storage_type) + .add_enum("default", kDefaultStorage) + .add_enum("row_sparse", kRowSparseStorage) + .add_enum("csr", kCSRStorage) + .describe("Output storage type."); + } +}; + +inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_NE(in_attrs->at(0), kUndefinedStorage) + << "src ndarray's storage type must be specified"; + const CastStorageParam& param = nnvm::get(attrs.parsed); + CHECK_NE(param.storage_type, kUndefinedStorage) + << "dst ndarray's storage type must be specified"; + TYPE_ASSIGN_CHECK(*out_attrs, 0, param.storage_type); + return true; +} + +template +void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + CastStorageComputeImpl(s, inputs[0], outputs[0]); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_ diff --git a/src/operator/nn/cast_storage.cc b/src/operator/nn/cast_storage.cc new file mode 100644 index 000000000000..21c13e8fa564 --- /dev/null +++ b/src/operator/nn/cast_storage.cc @@ -0,0 +1,31 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file cast_storage.cc + * \brief CPU Implementation of cast_storage operator. + */ + +#include "./cast_storage-inl.h" +#include "../elemwise_op_common.h" +#include "../tensor/elemwise_unary_op.h" + +namespace mxnet { +namespace op { + +// TODO(haibin) declare backward op for cast storage +DMLC_REGISTER_PARAMETER(CastStorageParam); +NNVM_REGISTER_OP(cast_storage) +.describe(R"code(Casts tensor storage type to the new type. +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", CastStorageInferStorageType) +.set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", CastStorageComputeEx) +.add_argument("data", "NDArray-or-Symbol", "The input.") +.add_arguments(CastStorageParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/cast_storage.cu b/src/operator/nn/cast_storage.cu new file mode 100644 index 000000000000..79f369fb2054 --- /dev/null +++ b/src/operator/nn/cast_storage.cu @@ -0,0 +1,17 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file cast_storage.cu + * \brief GPU Implementation of cast_storage operator. + */ +#include "./cast_storage-inl.h" +#include "../tensor/elemwise_unary_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(cast_storage) +.set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", CastStorageComputeEx); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index a43d092bceb6..3d88c9047e3a 100755 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -11,12 +11,15 @@ #include #include #include +#include +#include #include #include #include #include #include #include "../common/cuda_utils.h" +#include "../common/utils.h" namespace mxnet { namespace op { @@ -107,6 +110,19 @@ inline std::string type_string(const int& x) { return "unknown"; } +/*! \brief get string representation of storage_type */ +inline std::string stype_string(const int& x) { + switch (x) { + case kDefaultStorage: + return "default"; + case kCSRStorage: + return "csr"; + case kRowSparseStorage: + return "row_sparse"; + } + return "unknown"; +} + /*! * \brief Assign x to y. Checks for compatiblity when y is not empty. * Allow missing dim in both x and y (as 0). @@ -183,6 +199,24 @@ inline bool type_assign(int *y, const int& x) { } \ } +/*! + * \brief macro assign type to out if out is unknown (-1) otherwise check consistency + * Use macro so we can see the error file more clearly + * \param type_array the storage type array to store the result + * \param index the index of in the array + * \param type the inferred storage type + */ +#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \ + { \ + if (!type_assign(&(type_array)[index], type)) { \ + std::ostringstream os; \ + os << "Storage type inconsistent, Provided=" \ + << stype_string((type_array)[index]) << ',' \ + << " inferred storage type=" << stype_string(type); \ + throw ::mxnet::op::InferTypeError(os.str(), index); \ + } \ + } + // helper macro to implement bind dispatch #if MXNET_USE_CUDA #define DO_BIND_DISPATCH(Method, ...) 
\ @@ -315,6 +349,33 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) { attrs->parsed = std::move(param); } +template +void FCompExFallback(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + FCompute fcompute, + const std::string& fname) { + using namespace mxnet::common; + std::vector in_blobs, out_blobs; + std::vector temp_in, temp_out; + GetDefaultBlobs(inputs, &in_blobs, &temp_in, ctx, true); + GetDefaultBlobs(outputs, &out_blobs, &temp_out, ctx, true); + fcompute(attrs, ctx, in_blobs, req, out_blobs); + CastNonDefaultStorage(outputs, temp_out, ctx, true); +} + +#define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param) \ + { \ + CHECK(rsp.storage_shape()[0] == rsp.shape()[0]) << func \ + << " for RowSparse " << param << " is only implemented for " \ + << "RowSparse " << param << " with all rows containing non-zeros. " \ + << "Expects " << param << ".values.shape[0] (" << rsp.storage_shape()[0] \ + << ") == " << param << ".shape[0] (" << rsp.shape()[0] << ")."; \ + } + + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 85091c008ab4..176da461f31f 100755 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -18,6 +18,7 @@ #include "./mshadow_op.h" #include "./elemwise_op_common.h" #include "mxnet_op.h" +#include "./tensor/init_op.h" namespace mxnet { namespace op { @@ -84,6 +85,173 @@ inline void SGDUpdate(const nnvm::NodeAttrs& attrs, }); } +/*! \brief kernel for sparse sgd + */ +template +struct SGDDnsRspKernel { + // DType is the output data type + // IType is row sparse idx type + // i is the ith row in row sparse gradient + template + MSHADOW_XINLINE static void Map(int i, size_t width, DType* out, const DType* weight, + const IType* grad_idx, const DType *grad_val, + const DType clip_gradient, const DType lr, + const DType wd, const DType rescale_grad) { + for (size_t j = 0; j < width; j++) { + uint64_t data_i = grad_idx[i] * width + j; + uint64_t grad_i = i * width + j; + if (clip_gradient >= 0.0f) { + KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - + (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[grad_i], clip_gradient)); + } else { + KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - + (lr * rescale_grad) * grad_val[grad_i]); + } + } + } +}; + +template +inline void SGDUpdateDnsRspImpl(const SGDParam& param, + const OpContext &ctx, + const TBlob& weight, + const NDArray& grad, + const OpReqType& req, + TBlob *out) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + CHECK_EQ(grad.storage_type(), kRowSparseStorage); + // if gradients are zeros, no weights are updated + if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_GT(weight.shape_.Size(), 0); + + MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.dptr(); + auto grad_idx = grad.aux_data(rowsparse::kIdx).dptr(); + auto grad_val = grad.data().dptr(); + auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; + auto width = weight.shape_.ProdShape(1, weight.ndim()); + Kernel, xpu>::Launch(s, num_rows, width, + out->dptr(), weight_data, grad_idx, grad_val, + static_cast(param.clip_gradient), + static_cast(param.lr), 
static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); + }); +} + +/*! \brief kernel for sparse sgd + */ +template +struct SGDRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, const DType* weight, + const DType *grad, const DType clip_gradient, const DType lr, + const DType wd, const DType rescale_grad) { + bool contains_non_zeros = false; + index_t j = 0; + index_t offset = i * num_cols; + for (; j < num_cols; ++j) { + if (grad[offset + j] != 0) { + contains_non_zeros = true; + break; + } + } + if (!contains_non_zeros) return; + const DType rate = 1.f - lr * wd; + for (index_t j = 0; j < num_cols; j++) { + auto index = offset + j; + if (clip_gradient >= 0.0f) { + KERNEL_ASSIGN(out[index], req, rate * weight[index] - + lr * mshadow_op::clip::Map(rescale_grad * grad[index], clip_gradient)); + } else { + KERNEL_ASSIGN(out[index], req, rate * weight[index] - + lr * rescale_grad * grad[index]); + } + } + } +}; + +template +inline void SGDUpdateRspDnsImpl(const SGDParam& param, + const OpContext &ctx, + const NDArray& weight, + const TBlob& grad, + const OpReqType req, + NDArray *out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace rowsparse; + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights"); + CHECK_EQ(weight.storage_type(), kRowSparseStorage); + if (req == kNullOp) return; + CHECK(weight.storage_initialized()); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.data().dptr(); + auto grad_data = grad.dptr(); + auto num_rows = weight.aux_shape(kIdx)[0]; + auto num_cols = weight.shape().ProdShape(1, weight.shape().ndim()); + Kernel, xpu>::Launch(s, num_rows, num_cols, + out->data().dptr(), weight_data, grad_data, + static_cast(param.clip_gradient), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); +} + +template +inline void SGDUpdateRspRspImpl(const SGDParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const OpReqType& req, + NDArray *out) { + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights"); + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. 
+ // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + // reuse dns rsp implementation when storage_shape == shape + TBlob out_blob = out->data(); + SGDUpdateDnsRspImpl(param, ctx, weight.data(), grad, out_req, &out_blob); +} + +template +inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + const SGDParam& param = nnvm::get(attrs.parsed); + auto weight_stype = inputs[0].storage_type(); + auto grad_stype = inputs[1].storage_type(); + if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) { + TBlob out = outputs[0].data(); + SGDUpdateDnsRspImpl(param, ctx, inputs[0].data(), inputs[1], req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDUpdateRspRspImpl(param, ctx, inputs[0], inputs[1], req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage) { + NDArray out = outputs[0]; + SGDUpdateRspDnsImpl(param, ctx, inputs[0], inputs[1].data(), req[0], &out); + } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) { + FCompExFallback(attrs, ctx, inputs, req, outputs, SGDUpdate, "SGDUpdate"); + } +} + struct SGDMomParam : public dmlc::Parameter { float lr; float momentum; @@ -153,6 +321,206 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs, }); } +template +struct SGDMomDnsRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, size_t width, DType* out_data, + DType* mom_data, const DType* weight_data, const IType* grad_idx, + const DType* grad_data, const DType clip_gradient, const DType momentum, + const DType lr, const DType wd, const DType rescale_grad) { + const DType rate = lr * wd; + for (size_t j = 0; j < width; j++) { + uint64_t data_i = grad_idx[i] * width + j; + uint64_t grad_i = i * width + j; + if (clip_gradient >= 0.0f) { + mom_data[data_i] = momentum * mom_data[data_i] + - rate * weight_data[data_i] + - lr * + mshadow_op::clip::Map(rescale_grad * grad_data[grad_i], + clip_gradient); + } else { + mom_data[data_i] = momentum * mom_data[data_i] + - rate * weight_data[data_i] + - lr * rescale_grad * grad_data[grad_i]; + } + KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]); + } + } +}; + +template +inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, + const OpContext& ctx, + const TBlob& weight, + const NDArray& grad, + const TBlob& mom, + const OpReqType& req, + TBlob *out) { + using namespace mxnet_op; + using namespace rowsparse; + Stream* s = ctx.get_stream(); + if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_GT(weight.shape_.Size(), 0); + CHECK_GT(mom.shape_.Size(), 0); + + MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.dptr(); + auto grad_idx = grad.aux_data(kIdx).dptr(); + auto grad_val = grad.data().dptr(); + auto mom_data = mom.dptr(); + auto out_data = out->dptr(); + auto num_rows = grad.aux_shape(kIdx)[0]; + auto width = weight.shape_.ProdShape(1, weight.ndim()); + Kernel, xpu>::Launch(s, num_rows, width, + out_data, mom_data, weight_data, grad_idx, grad_val, + static_cast(param.clip_gradient), 
static_cast(param.momentum), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); + }); +} + +template +struct SGDMomRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, DType* mom, + const DType* weight, const DType *grad, + const DType clip_gradient, const DType momentum, + const DType lr, const DType wd, const DType rescale_grad) { + bool contains_non_zeros = false; + index_t j = 0; + index_t offset = i * num_cols; + for (; j < num_cols; ++j) { + if (grad[offset + j] != 0) { + contains_non_zeros = true; + break; + } + } + if (!contains_non_zeros) return; + const DType rate = lr * wd; + for (index_t j = 0; j < num_cols; j++) { + auto index = offset + j; + if (clip_gradient >= 0.0f) { + mom[index] = momentum * mom[index] - rate * weight[index] + - lr * mshadow_op::clip::Map(rescale_grad * grad[index], clip_gradient); + } else { + mom[index] = momentum * mom[index] - rate * weight[index] + - lr * rescale_grad * grad[index]; + } + KERNEL_ASSIGN(out[index], req, weight[index] + mom[index]); + } + } +}; + +template +inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param, + const OpContext &ctx, + const NDArray& weight, + const TBlob& grad, + const NDArray& mom, + const OpReqType req, + NDArray *out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace rowsparse; + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights"); + Stream* s = ctx.get_stream(); + CHECK_EQ(weight.storage_type(), kRowSparseStorage); + if (req == kNullOp) return; + CHECK(weight.storage_initialized()); + // fill mom with zero values if not initialized yet + if (!mom.storage_initialized()) { + NDArray mom_zeros = mom; + FillDnsZerosRspImpl(s, &mom_zeros); + } + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. + // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, { + auto weight_data = weight.data().dptr(); + auto grad_data = grad.dptr(); + auto mom_data = mom.data().dptr(); + auto num_rows = weight.aux_shape(kIdx)[0]; + auto num_cols = weight.shape().ProdShape(1, weight.shape().ndim()); + Kernel, xpu>::Launch(s, num_rows, num_cols, + out->data().dptr(), mom_data, weight_data, grad_data, + static_cast(param.clip_gradient), static_cast(param.momentum), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); +} + + +template +inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const NDArray& mom, + const OpReqType& req, + NDArray *out) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + using namespace rowsparse; + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights"); + Stream* s = ctx.get_stream(); + // fill mom with zero values in order to reuse the sgd mom dns impl + if (!mom.storage_initialized()) { + NDArray mom_zeros = mom; + FillDnsZerosRspImpl(s, &mom_zeros); + } + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. 
+ // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + TBlob out_blob = out->data(); + // reuse dns rsp implementation when storage_shape == shape + SGDMomUpdateDnsRspDnsImpl(param, ctx, weight.data(), grad, + mom.data(), out_req, &out_blob); +} + +template +inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const SGDMomParam& param = nnvm::get(attrs.parsed); + auto &weight = inputs[0]; + auto &grad = inputs[1]; + auto &mom = inputs[2]; + auto weight_stype = weight.storage_type(); + auto grad_stype = grad.storage_type(); + auto mom_stype = mom.storage_type(); + if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage && + mom_stype == kDefaultStorage) { + TBlob out = outputs[0].data(); + SGDMomUpdateDnsRspDnsImpl(param, ctx, weight.data(), grad, + mom.data(), req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage && + mom_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDMomUpdateRspRspRspImpl(param, ctx, weight, grad, mom, req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage && + mom_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDMomUpdateRspDnsImpl(param, ctx, weight, grad.data(), mom, req[0], &out); + } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage && + mom_stype == kDefaultStorage) { + FCompExFallback(attrs, ctx, inputs, req, outputs, SGDMomUpdate, "SGDMomUpdate"); + } +} + struct AdamParam : public dmlc::Parameter { float lr; float beta1; diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 9ec6aacaafac..5c8bedcb0ebc 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -22,6 +22,9 @@ It updates the weights using:: weight = weight - learning_rate * gradient +If weights are stored with `row_sparse` storage, +update is applied only to rows whose gradient has non-zero entries. + )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -29,6 +32,7 @@ It updates the weights using:: .set_attr("FInferShape", ElemwiseShape<2, 1>) .set_attr("FInferType", ElemwiseType<2, 1>) .set_attr("FCompute", SGDUpdate) +.set_attr("FComputeEx", SGDUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_arguments(SGDParam::__FIELDS__()); @@ -52,6 +56,9 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. +If weights are stored with `row_sparse` storage, +only rows whose gradients contain non-zero entries are updated (for both weight and momentum). + )code" ADD_FILELINE) .set_num_inputs(3) .set_num_outputs(1) @@ -63,12 +70,12 @@ Where the parameter ``momentum`` is the decay rate of momentum estimates at each return std::vector{2}; }) .set_attr("FCompute", SGDMomUpdate) +.set_attr("FComputeEx", SGDMomUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_argument("mom", "NDArray-or-Symbol", "Momentum") .add_arguments(SGDMomParam::__FIELDS__()); - NNVM_REGISTER_OP(adam_update) .describe(R"code(Update function for Adam optimizer. Adam is seen as a generalization of AdaGrad. 
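
The row-sparse momentum update registered above touches only the rows named in the gradient's index array, which is what makes it cheap for embedding-style weights. Below is a minimal standalone sketch of that update rule for a dense weight/momentum and a row-sparse gradient given as (row index, row values) pairs; the clipping and update order follow the SGDMomDnsRspDnsKernel above, but the plain std::vector containers are stand-ins for the MXNet tensors.

// Illustrative sketch only, not the actual MXNet kernel.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// grad_idx[i] is the weight row that row i of grad_val refers to.
void SgdMomUpdateRowSparse(std::vector<float>* weight,          // [num_rows * width]
                           std::vector<float>* mom,             // same layout as weight
                           const std::vector<int64_t>& grad_idx,
                           const std::vector<float>& grad_val,  // [grad_idx.size() * width]
                           size_t width, float lr, float wd, float momentum,
                           float rescale_grad, float clip_gradient) {
  for (size_t i = 0; i < grad_idx.size(); ++i) {
    for (size_t j = 0; j < width; ++j) {
      const size_t w = static_cast<size_t>(grad_idx[i]) * width + j;  // weight offset
      const size_t g = i * width + j;                                 // gradient offset
      float grad = rescale_grad * grad_val[g];
      if (clip_gradient >= 0.0f) {
        grad = std::max(-clip_gradient, std::min(grad, clip_gradient));  // clip to [-c, c]
      }
      // mom = momentum * mom - lr * wd * weight - lr * grad; weight += mom
      (*mom)[w] = momentum * (*mom)[w] - lr * wd * (*weight)[w] - lr * grad;
      (*weight)[w] += (*mom)[w];
    }
  }
}

Rows that never appear in grad_idx keep both their weight and momentum untouched, matching the operator description above.
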
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 2b2667ec317b..3445bafc87cc 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -10,10 +10,12 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(sgd_update) -.set_attr("FCompute", SGDUpdate); +.set_attr("FCompute", SGDUpdate) +.set_attr("FComputeEx", SGDUpdateEx); NNVM_REGISTER_OP(sgd_mom_update) -.set_attr("FCompute", SGDMomUpdate); +.set_attr("FCompute", SGDMomUpdate) +.set_attr("FComputeEx", SGDMomUpdateEx); NNVM_REGISTER_OP(adam_update) .set_attr("FCompute", AdamUpdate); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc index 0d0a1d8b5df0..f6f8f429d99e 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc @@ -105,6 +105,7 @@ Example:: .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); + NNVM_REGISTER_OP(_backward_broadcast_mul) .set_num_inputs(3) .set_num_outputs(2) diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 6062febe2d9e..222b0d1ffc31 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -10,10 +10,11 @@ #include #include #include +#include #include "../mxnet_op.h" #include "../mshadow_op.h" #include "../elemwise_op_common.h" -#include "../mxnet_op.h" +#include "../../common/utils.h" namespace mxnet { namespace op { @@ -123,6 +124,109 @@ void BinaryBackwardUseNone_(const nnvm::NodeAttrs& attrs, } } +// TODO(haibin) This is a single-thread inefficient implementation +// Binary Compute between two row-sparse ndarray +// This implementation only works on CPU +template +void BinaryComputeRspRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + auto &lhs = inputs[0]; + auto &rhs = inputs[1]; + auto &output = outputs[0]; + + bool init_l = lhs.storage_initialized(); + bool init_r = rhs.storage_initialized(); + // both inputs are zeros + if (!init_l && !init_r) return; + // Memory Estimation: This is (roughly) the number of result rows. 
We still + // need to subtract the number of common rows + unsigned int num_rows_l = lhs.aux_shape(rowsparse::kIdx).Size(); + unsigned int num_rows_r = rhs.aux_shape(rowsparse::kIdx).Size(); + output.CheckAndAlloc({mshadow::Shape1(num_rows_l + num_rows_r)}); + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_TYPE_SWITCH(lhs.aux_type(rowsparse::kIdx), IType, { + // Indices + auto indices_l = lhs.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto indices_r = rhs.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto indices_out = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + // Data + auto data_l = lhs.data().FlatTo2D(s); + auto data_r = rhs.data().FlatTo2D(s); + auto out = output.data().FlatTo2D(s); + + // TODO(haibin) A more appropriate way: Copy to output, then apply ops + size_t iter_l = 0; + size_t iter_r = 0; + size_t iter_out = 0; + int32_t num_common_rows = 0; + while (iter_l < num_rows_l && iter_r < num_rows_r) { + auto idx_l = indices_l[iter_l]; + auto idx_r = indices_r[iter_r]; + if (idx_l == idx_r) { + // Same row + indices_out[iter_out] = idx_l; + mshadow::Copy(out[iter_out], data_l[iter_l++], s); + out[iter_out] += data_r[iter_r++]; + num_common_rows++; + } else if (idx_l < idx_r) { + // Left only + indices_out[iter_out] = idx_l; + mshadow::Copy(out[iter_out], data_l[iter_l++], s); + } else { + // Right only + indices_out[iter_out] = idx_r; + mshadow::Copy(out[iter_out], data_r[iter_r++], s); + } + iter_out++; + } + // Copying over the rest of the rows + while (iter_l < num_rows_l) { + indices_out[iter_out] = indices_l[iter_l]; + mshadow::Copy(out[iter_out++], data_l[iter_l++], s); + } + while (iter_r < num_rows_r) { + indices_out[iter_out] = indices_r[iter_r]; + mshadow::Copy(out[iter_out++], data_r[iter_r++], s); + } + auto new_shape = output.aux_shape(rowsparse::kIdx); + new_shape[0] -= num_common_rows; + output.set_aux_shape(rowsparse::kIdx, new_shape); + }); + }); +} + +template +void BinaryComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + if (typeid(OP) == typeid(mshadow::op::plus)) { + // If any input is dense, fallback to FCompute + // TODO(haibin) implement dns + rsp in a separate kernel + if (mxnet::common::ContainsDefaultStorage(inputs)) { + FCompExFallback(attrs, ctx, inputs, req, outputs, + BinaryCompute, "BinaryCompute"); + return; + } + CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; + CHECK_EQ(inputs[1].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; + BinaryComputeRspRsp(attrs, ctx, inputs, req, outputs); + return; + } else { + LOG(FATAL) << "Not implemented"; + } +} + template void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -134,6 +238,55 @@ void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, }); } +// Only implemented for _backward_add for now +template +void BinaryBackwardUseNoneRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage); + CHECK_EQ(outputs[0].storage_type(), kRowSparseStorage); + CHECK_EQ(outputs[1].storage_type(), 
kRowSparseStorage); + CHECK(typeid(LOP) == typeid(mshadow_op::identity)); + CHECK(typeid(ROP) == typeid(mshadow_op::identity)); + TShape shape = inputs[0].aux_shape(rowsparse::kIdx); + outputs[0].CheckAndAlloc({shape}); + outputs[1].CheckAndAlloc({shape}); + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + MSHADOW_TYPE_SWITCH(outputs[0].aux_type(rowsparse::kIdx), IType, { + auto lgrad_idx = outputs[0].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto rgrad_idx = outputs[1].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto ograd_idx = inputs[0].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto lgrad = outputs[0].data().FlatTo1D(s); + Tensor rgrad = outputs[1].data().FlatTo1D(s); + Tensor ograd = inputs[0].data().FlatTo1D(s); + ASSIGN_DISPATCH(lgrad, req[0], F(ograd)); + ASSIGN_DISPATCH(rgrad, req[1], F(ograd)); + ASSIGN_DISPATCH(lgrad_idx, req[0], F(ograd_idx)); + ASSIGN_DISPATCH(rgrad_idx, req[1], F(ograd_idx)); + }); + }); +} +// Only implemented for _backward_add for now +template +void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + auto stype = inputs[0].storage_type(); + CHECK_EQ(stype, kRowSparseStorage) << "Not implemented yet"; + BinaryBackwardUseNoneRsp(attrs, ctx, inputs, req, outputs); + // TODO(haibin) fallback for kDefaultStorage +} + template void BinaryBackwardUseNoneWithHalf2(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -214,7 +367,7 @@ void BinaryBackwardUseInWithHalf2(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}, {1, 0}}; \ }) \ - .add_argument("lhs", "NDArray-or-Symbol", "first input") \ + .add_argument("lhs", "NDArray-or-Symbol", "first input") \ .add_argument("rhs", "NDArray-or-Symbol", "second input") } // namespace op diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index be4c1d88e983..c9e5b21470d9 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -12,7 +12,9 @@ MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) .add_alias("_add").add_alias("_plus").add_alias("_Plus") .describe("Adds arguments element-wise.") .set_attr("FCompute", BinaryCompute) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_add"}); +.set_attr("FComputeEx", BinaryComputeEx) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_add"}) +.set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); // specialized gradient add function to do add to optimization // this must differ from elemwise_add to prevent add to optimization in forward pass. 
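The row_sparse + row_sparse forward path registered above (BinaryComputeRspRsp) merges the two sorted row-index arrays: rows present in both operands are summed, the remaining rows are copied, and the output index count is then trimmed by the number of common rows. A rough NumPy sketch of the result this produces (operand values below are made up purely for illustration):

    import numpy as np

    # Two row_sparse operands: sorted row indices plus one dense value row each.
    idx_l, val_l = np.array([0, 2]), np.array([[1., 1.], [2., 2.]])
    idx_r, val_r = np.array([2, 3]), np.array([[10., 10.], [4., 4.]])

    # Merge the index sets; rows appearing in both operands are added, others copied.
    idx_out = np.union1d(idx_l, idx_r)                  # [0, 2, 3]
    val_out = np.zeros((idx_out.size, val_l.shape[1]))
    val_out[np.searchsorted(idx_out, idx_l)] += val_l
    val_out[np.searchsorted(idx_out, idx_r)] += val_r   # [[1, 1], [12, 12], [4, 4]]
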
@@ -28,7 +30,10 @@ NNVM_REGISTER_OP(_backward_add) return std::vector >{{0, 0}, {0, 1}}; }) .set_attr("FCompute", BinaryBackwardUseNone); + mshadow_op::identity>) +.set_attr("FComputeEx", + BinaryBackwardUseNoneEx) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); MXNET_OPERATOR_REGISTER_BINARY(_sub) .add_alias("_minus").add_alias("_Minus") diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu index ff432380d6d1..b75ce8118c2f 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_op_basic.cu @@ -9,7 +9,8 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(elemwise_add) -.set_attr("FCompute", BinaryComputeWithHalf2); +.set_attr("FCompute", BinaryComputeWithHalf2) +.set_attr("FComputeEx", BinaryComputeEx); NNVM_REGISTER_OP(_grad_add) .set_attr("FCompute", BinaryComputeWithHalf2); @@ -17,7 +18,9 @@ NNVM_REGISTER_OP(_grad_add) NNVM_REGISTER_OP(_backward_add) .set_attr("FCompute", BinaryBackwardUseNoneWithHalf2); + mshadow_op::identity, mshadow_op::identity>) +.set_attr("FComputeEx", + BinaryBackwardUseNoneEx); NNVM_REGISTER_OP(_sub) .set_attr("FCompute", BinaryComputeWithHalf2); diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index 073bbe16d491..372e94509a68 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -124,7 +124,9 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) .set_attr("FIgnoreInputs", [](const NodeAttrs& attrs) { return std::vector(1, 1); }) .set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", IdentityLikeRhsComputeEx) .set_attr("FInferShape", ElemwiseShape<2, 1>) +.set_attr("FInferStorageType", IdentityAttrLikeRhsStorageType) .set_attr( "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { @@ -169,6 +171,7 @@ NNVM_REGISTER_OP(_backward_cast) .set_attr("TIsBackward", true) .set_attr("FCompute", CastCompute); + // negative MXNET_OPERATOR_REGISTER_UNARY(negative) .MXNET_DESCRIBE("Negate src") diff --git a/src/operator/tensor/elemwise_unary_op.cu b/src/operator/tensor/elemwise_unary_op.cu index 746b39fe4c8c..b8fa59f5d04e 100644 --- a/src/operator/tensor/elemwise_unary_op.cu +++ b/src/operator/tensor/elemwise_unary_op.cu @@ -35,7 +35,9 @@ NNVM_REGISTER_OP(make_loss) // identity output as first input, but attributes are constrainted to be like rhs NNVM_REGISTER_OP(_identity_with_attr_like_rhs) -.set_attr("FCompute", IdentityCompute); +.set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", IdentityLikeRhsComputeEx); + NNVM_REGISTER_OP(Cast) .set_attr("FCompute", CastCompute); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 97a7e36535f0..f3aab781eddb 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -13,15 +13,16 @@ #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "../special_functions-inl.h" +#include "./broadcast_reduce-inl.h" namespace mxnet { namespace op { template void UnaryLaunch(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mxnet_op; Stream *s = ctx.get_stream(); @@ -77,6 +78,54 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, }); } +template +void 
IdentityComputeRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + auto &input = inputs[0]; + auto &output = outputs[0]; + CHECK_NE(req[0], kNullOp) << "kNullOp in IdentityComputeEx not supported yet"; + CHECK_NE(req[0], kWriteInplace) << "kWriteInplace in IdentityComputeEx not supported yet"; + if (!input.storage_initialized()) return; + TShape shape = input.aux_shape(rowsparse::kIdx); + output.CheckAndAlloc({shape}); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_TYPE_SWITCH(output.aux_type(rowsparse::kIdx), AuxType, { + auto out_d = output.data().FlatTo1D(s); + auto out_aux = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto in_aux = input.aux_data(rowsparse::kIdx).FlatTo1D(s); + ASSIGN_DISPATCH(out_d, req[0], + F(input.data().FlatTo1D(s))); + ASSIGN_DISPATCH(out_aux, req[0], F(in_aux)); + }); + }); +} + +template +void IdentityLikeRhsComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + Stream *s = ctx.get_stream(); + size_t rhs_idx = 1; + NDArrayStorageType stype = inputs[rhs_idx].storage_type(); + if (stype == kRowSparseStorage) { + IdentityComputeRsp(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not implemented yet"; + } +} + struct CastParam : public dmlc::Parameter { // use int for enumeration int dtype; @@ -168,4 +217,5 @@ struct relu_grad { } // namespace op } // namespace mxnet + #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_H_ diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 5f010fdfc62c..dfe53cf4614e 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -86,6 +86,48 @@ NNVM_REGISTER_OP(_backward_Embedding) .set_attr("TIsBackward", true) .set_attr("FCompute", EmbeddingOpBackward); +NNVM_REGISTER_OP(SparseEmbedding) +.describe(R"doc(Represents words or other sparse inputs by dense continuous vectors. +It assumes that the input is in one-hot form. E.g., for a vocabulary size of 10,000, + each input vector is expected to have dimension 10,000. +The index of the non-zero entry is the index of the word or item it represents. + +The corresponding embedding vectors are stored as rows of a matrix. +Hence, mapping an input word to its embedding is implemented as a matrix product. + +The gradient of an embedding matrix has the form of gradient vectors that are only + non-zero for words seen in a minibatch. 
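A small worked example (values are illustrative; ``one_hot`` is only shorthand for the csr
one-hot encoding of the input word ids)::

  data = one_hot([2, 0])                     # csr, shape (2, 4), vocabulary size 4
  weight = [[ 0.,  1.],
            [ 2.,  3.],
            [ 4.,  5.],
            [ 6.,  7.]]                      # row_sparse, shape (4, 2)
  SparseEmbedding(data, weight) = [[ 4.,  5.],
                                   [ 0.,  1.]]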
+)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "weight"}; + }) +.set_attr("FInferShape", SparseEmbeddingShape) +.set_attr("FInferType", EmbeddingOpType) +.set_attr("FInferStorageType", SparseEmbeddingForwardStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FComputeEx", SparseEmbeddingForwardEx) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds, + {n->inputs[0]}, n->attrs.dict); + }) +.add_argument("data", "NDArray-or-Symbol", + "The input array to the sparse embedding operator.") +.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.") +.add_arguments(EmbeddingParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_SparseEmbedding) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FComputeEx", SparseEmbeddingBackwardEx); NNVM_REGISTER_OP(take) .describe(R"code(Takes elements from an input array along the given axis. @@ -230,5 +272,46 @@ Examples:: .add_argument("indices", "NDArray-or-Symbol", "array of locations where to set on_value") .add_arguments(OneHotParam::__FIELDS__()); +NNVM_REGISTER_OP(sparse_retain) +.describe(R"code(pick rows specified by user input index array from a row sparse matrix +and save them in the output sparse matrix. + +Example:: + + data = [[1, 2], [3, 4], [5, 6]] + indices = [0, 1, 3] + shape = (4, 2) + rsp_in = row_sparse(data, indices) + to_retain = [0, 3] + rsp_out = sparse_retain(rsp_in, to_retain) + rsp_out.values = [[1, 2], [5, 6]] + rsp_out.indices = [0, 3] + +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "indices"}; + }) +.set_attr("FInferShape", SparseRetainOpShape) +.set_attr("FInferType", SparseRetainOpType) +.set_attr("FInferStorageType", SparseRetainForwardInferStorageType) +.set_attr("FComputeEx", SparseRetainOpForwardEx) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + return MakeNonlossGradNode("_backward_sparse_retain", n, ograds, + {n->inputs[sr::kIdx]}, n->attrs.dict); + }) +.add_argument("data", "NDArray-or-Symbol", "The input array for sparse_retain operator.") +.add_argument("indices", "NDArray-or-Symbol", "The index array of rows ids that will be retained."); + +NNVM_REGISTER_OP(_backward_sparse_retain) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInferStorageType", SparseRetainBackwardInferStorageType) +.set_attr("FComputeEx", SparseRetainOpBackwardEx); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 287ec25d70be..4378bd574932 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -26,6 +26,12 @@ NNVM_REGISTER_OP(batch_take) NNVM_REGISTER_OP(one_hot) .set_attr("FCompute", OneHotOpForward); +NNVM_REGISTER_OP(sparse_retain) +.set_attr("FComputeEx", SparseRetainOpForwardEx); + +NNVM_REGISTER_OP(_backward_sparse_retain) +.set_attr("FComputeEx", SparseRetainOpBackwardEx); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 5fd6e81d0b2f..b2a67f73af78 100644 --- a/src/operator/tensor/indexing_op.h 
+++ b/src/operator/tensor/indexing_op.h @@ -22,6 +22,7 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #include "./sort_op.h" +#include "./matrix_op-inl.h" namespace mxnet { namespace op { @@ -203,6 +204,78 @@ void EmbeddingOpForward(const nnvm::NodeAttrs& attrs, }); } +template +void SparseEmbeddingForwardRspImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const NDArray& data, + const NDArray& weight, + const OpReqType req, + NDArray *out) { + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SparseEmbedding", "weight"); + TBlob out_blob = out->data(); + // forward to dns implementation when storage_shape equals shape + bool transpose_a = false; + DotCsrRspDnsImpl(ctx, data, weight, req, transpose_a, &out_blob); +} + +template +void SparseEmbeddingForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(req[embedding::kOut], kWriteTo); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + NDArray output = outputs[embedding::kOut]; + auto data_stype = inputs[embedding::kData].storage_type(); + auto weight_stype = inputs[embedding::kWeight].storage_type(); + auto out_stype = outputs[embedding::kOut].storage_type(); + if (data_stype == kCSRStorage && weight_stype == kRowSparseStorage && + out_stype == kDefaultStorage) { + NDArray ret = outputs[embedding::kOut]; + SparseEmbeddingForwardRspImpl(attrs, ctx, inputs[embedding::kData], + inputs[embedding::kWeight], + req[embedding::kOut], &ret); + } else { + LOG(FATAL) << "Not supported SparseEmbedding operation for data.storage_type = " + << data_stype << ", weight.storage_type = " << weight_stype + << ", out.storage_type = " << out_stype; + } +} + +inline bool SparseEmbeddingForwardStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, embedding::kData, kCSRStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, embedding::kOut, kDefaultStorage); + // override the default storage type generated in nnvm + in_attrs->at(embedding::kWeight) = kRowSparseStorage; + return true; +} + +inline bool SparseEmbeddingShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + using namespace mshadow; + const EmbeddingParam& param = nnvm::get(attrs.parsed); + const TShape &dshape = (*in_attrs)[embedding::kData]; + CHECK_EQ(dshape.ndim(), 2) + << "SparseEmbedding shape error: data is expected to be 2D."; + SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, + Shape2(param.input_dim, param.output_dim)); + out_attrs->clear(); + std::vector buf(2); + buf[0] = dshape[0]; + buf[1] = param.output_dim; + out_attrs->emplace_back(buf.begin(), buf.end()); + return true; +} + // Returns integer log2(a) rounded up inline int ilog2(unsigned int a) { int k = 1; @@ -315,6 +388,31 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, }); } +template +void SparseEmbeddingBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + // CHECK_EQ(req[embedding::kData], kNullOp) + // << "Embedding layer doesn't support calculate data gradient" << req[0] << " " << req[1]; + // CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; + + auto 
data_stype = inputs[1].storage_type(); + auto grad_stype = inputs[0].storage_type(); + auto output_stype = outputs[1].storage_type(); + if (data_stype == kCSRStorage && grad_stype == kDefaultStorage && + output_stype == kDefaultStorage) { + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], true, &ret); + } else { + LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; + } +} + namespace take_ { // to avoid name conflict enum TakeOpInputs {kArr, kIdx}; enum TakeOpOutputs {kOut}; @@ -667,6 +765,199 @@ void OneHotOpForward(const nnvm::NodeAttrs& attrs, }); } +/*! + * \brief sparse retain namespace + */ +namespace sr { +enum SparseRetainOpInputs {kArr, kIdx}; +enum SparseRetainOpOutputs {kOut}; +} // namespace sr + +inline bool SparseRetainOpShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U) + << "sparse_retain operator takes 2 arguments (" << in_attrs->size() << " given)"; + CHECK_EQ(out_attrs->size(), 1U); + + TShape tshape((*in_attrs)[sr::kArr]); + shape_assign(&tshape, (*out_attrs)[sr::kOut]); + SHAPE_ASSIGN_CHECK(*in_attrs, sr::kArr, tshape); + SHAPE_ASSIGN_CHECK(*out_attrs, sr::kOut, tshape); + return true; +} + +inline bool SparseRetainOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_NE((*in_attrs)[sr::kIdx], -1) << "Index type must be set for sparse_retain operator"; + + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[sr::kArr]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[sr::kOut]); + return (*in_attrs)[0] != -1; +} + +inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + if (kRowSparseStorage == in_attrs->at(sr::kArr)) { + out_attrs->at(sr::kOut) = kRowSparseStorage; + } + return true; +} + +inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 2U); + out_attrs->at(sr::kArr) = kRowSparseStorage; + out_attrs->at(sr::kIdx) = kDefaultStorage; + return true; +} + +struct SparseRetainRspForward { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, RType* out_idx, + const DType* in_data, const RType* in_idx, + const IType* idx, const size_t nnr, + const size_t num_cols) { + const RType irow = idx[i]; + int j = -1, left = 0, right = nnr - 1; + while (left <= right) { + int m = left + (right - left) / 2; + const auto in_idx_m = in_idx[m]; + if (in_idx_m == irow) { + j = m; + break; + } else if (in_idx_m < irow) { + left = m + 1; + } else { + right = m - 1; + } + } + out_idx[i] = idx[i]; + if (j >= 0) { + const size_t in_offset = j * num_cols; + const size_t out_offset = i * num_cols; + for (size_t k = 0; k < num_cols; ++k) { + out_data[out_offset+k] = in_data[in_offset+k]; + } + } + } +}; + +template +void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'"; + + CHECK_EQ(inputs[sr::kArr].storage_type(), kRowSparseStorage) + << "sparse_retain operator 
only takes row sparse NDArray as input"; + CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage) + << "sparse_retain operator only takes default NDArray as its index array"; + CHECK_EQ(outputs[sr::kOut].storage_type(), kRowSparseStorage) + << "sparse_retain operator only outputs row sparse NDArray"; + + const NDArray& input_nd = inputs[sr::kArr]; + const TBlob idx_data = inputs[sr::kIdx].data(); + + if (req[sr::kOut] == kNullOp + || !input_nd.storage_initialized() + || idx_data.Size() == 0U) return; + + const TBlob input_data = input_nd.data(); + if (input_data.shape_[0] == 0) return; + const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx); + + NDArray output_nd = outputs[sr::kOut]; + output_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())}); + TBlob output_data = output_nd.data(); + TBlob output_idx = output_nd.aux_data(rowsparse::kIdx); + + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, { // output data type + MSHADOW_INT_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type + MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type + Kernel::Launch(s, output_data.Size(), output_data.dptr()); + Kernel::Launch(s, idx_data.Size(), output_data.dptr(), + output_idx.dptr(), input_data.dptr(), input_idx.dptr(), + idx_data.dptr(), input_data.shape_[0], input_data.shape_[1]); + }); + }); + }); +} + +template +struct SparseRetainRspBackward { + template + MSHADOW_XINLINE static void Map(int i, DType* in_grad, RType* in_grad_idx, + const DType* out_grad, const IType* idx, + const size_t num_cols) { + const RType irow = idx[i]; + in_grad_idx[i] = irow; + const size_t out_offset = irow * num_cols; + const size_t in_offset = i * num_cols; + for (size_t j = 0; j < num_cols; ++j) { + KERNEL_ASSIGN(in_grad[in_offset+j], req, out_grad[out_offset+j]); + } + } +}; + +template +void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_NE(req[sr::kArr], kWriteInplace); + CHECK_EQ(req[sr::kIdx], kNullOp) + << "sparse_retain does not support calculating gradients of indices"; + + CHECK_EQ(inputs[sr::kOut].storage_type(), kDefaultStorage) + << "sparse_retain backward only takes default NDArray as ograd"; + CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage) + << "sparse_retain backward only takes default NDArray as its index array"; + CHECK_EQ(outputs[sr::kArr].storage_type(), kRowSparseStorage) + << "sparse_retain backward only outputs row sparse NDArray as grad of input"; + + const TBlob out_grad_data = inputs[sr::kOut].data(); + const TBlob idx_data = inputs[sr::kIdx].data(); + + NDArray in_grad_nd = outputs[sr::kArr]; + in_grad_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())}); + TBlob in_grad_data = in_grad_nd.data(); + TBlob in_grad_idx = in_grad_nd.aux_data(rowsparse::kIdx); + + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(out_grad_data.type_flag_, DType, { // output data type + MSHADOW_INT_TYPE_SWITCH(in_grad_idx.type_flag_, RType, { // row index data type + MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type + MXNET_ASSIGN_REQ_SWITCH(req[sr::kArr], req_type, { + Kernel, xpu>::Launch( + s, in_grad_idx.Size(), in_grad_data.dptr(), in_grad_idx.dptr(), + out_grad_data.dptr(), idx_data.dptr(), out_grad_data.shape_[1]); + }); + 
}); + }); + }); +} + } // namespace op } // namespace mxnet #ifdef __CUDACC__ diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index 16f71fc7e4e3..679d1fb55bab 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -21,6 +21,7 @@ NNVM_REGISTER_OP(_zeros) .set_attr("FInferShape", InitShape) .set_attr("FInferType", InitType) .set_attr("FCompute", FillCompute) +.set_attr("FComputeEx", FillComputeZerosEx) .add_arguments(InitOpParam::__FIELDS__()); NNVM_REGISTER_OP(_ones) diff --git a/src/operator/tensor/init_op.cu b/src/operator/tensor/init_op.cu index a798f26db60d..7c643ee00129 100644 --- a/src/operator/tensor/init_op.cu +++ b/src/operator/tensor/init_op.cu @@ -9,7 +9,8 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_zeros) -.set_attr("FCompute", FillCompute); +.set_attr("FCompute", FillCompute) +.set_attr("FComputeEx", FillComputeZerosEx); NNVM_REGISTER_OP(_ones) .set_attr("FCompute", FillCompute); diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 5ce132d4bebf..bc885f3cecf5 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -15,6 +15,8 @@ #include #include #include "../elemwise_op_common.h" +#include "../mxnet_op.h" + namespace mxnet { namespace op { @@ -111,7 +113,6 @@ inline bool InitType(const nnvm::NodeAttrs& attrs, return true; } - template void FillCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -127,6 +128,72 @@ void FillCompute(const nnvm::NodeAttrs& attrs, }); } +// Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray, +// instead of the usual compact representation. +template +inline void FillDnsZerosRspImpl(mshadow::Stream *s, NDArray *dst) { + using namespace rowsparse; + using namespace mshadow::expr; + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(dst->storage_type(), kRowSparseStorage); + MSHADOW_REAL_TYPE_SWITCH(dst->dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(dst->aux_type(kIdx), IType, { + auto num_rows = dst->shape()[0]; + dst->CheckAndAlloc({Shape1(num_rows)}); + auto idx = dst->aux_data(kIdx).FlatTo1D(s); + auto val = dst->data(); + Kernel::Launch(s, val.Size(), val.dptr()); + ASSIGN_DISPATCH(idx, kWriteTo, range(0, num_rows, 1, 1)) + }); + }); +} + +// Fill a rsp NDArray with zeros by updating the aux shape. +template +void FillZerosRspImpl(mshadow::Stream *s, NDArray *dst) { + if (!dst->storage_initialized()) return; + // reset the shapes if it's not zeros + auto storage_shape = dst->storage_shape(); + storage_shape[0] = 0; + dst->set_aux_shape(rowsparse::kIdx, TShape(mshadow::Shape1(0))); + dst->set_storage_shape(storage_shape); +} + +// Fill a CSR NDArray with zeros by updating the aux shape. 
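// A rough front-end sketch of what FillComputeZerosEx (further below) enables, using the
// keyword arguments from this PR's Python tests: the array keeps its full shape but stores
// no values, because the aux shapes are reset to zero.
//   z = mx.nd.zeros(shape=(2, 3), storage_type='csr')
//   z.asnumpy()    # [[0., 0., 0.], [0., 0., 0.]] with empty indptr/idx/data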
+template +void FillZerosCsrImpl(mshadow::Stream *s, NDArray *dst) { + if (!dst->storage_initialized()) return; + // reset the shapes if it's not zeros + TShape new_shape(mshadow::Shape1(0)); + dst->set_aux_shape(csr::kIndPtr, new_shape); + dst->set_aux_shape(csr::kIdx, new_shape); + dst->set_storage_shape(new_shape); +} + +// This operator never needs to fall back, since there's no input NDArray +template +void FillComputeZerosEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(outputs.size(), 1); + CHECK_EQ(inputs.size(), 0); + auto stype = outputs[0].storage_type(); + if (stype == kRowSparseStorage) { + NDArray nd(outputs[0]); + FillZerosRspImpl(s, &nd); + } else if (stype == kCSRStorage) { + NDArray nd(outputs[0]); + FillZerosCsrImpl(s, &nd); + } else { + LOG(FATAL) << "storage type not implemented."; + } +} template void RangeCompute(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index cdc8819da18e..8ba10bfd0e27 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "../mxnet_op.h" @@ -476,6 +477,266 @@ void DotBackward_(const nnvm::NodeAttrs& attrs, } } +inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + out_attrs->at(0) = kDefaultStorage; + return true; +} + +inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 2U); + out_attrs->at(0) = kDefaultStorage; + out_attrs->at(1) = kDefaultStorage; + return true; +} + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrDnsDns { + /*! + * \brief This function represents performing an inner product between a row of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_cols number of columns of output + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols) { + const int irow = i / num_cols; // row id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { + const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs + sum += data_l[j] * data_r[cur_col*num_cols+icol]; + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrTransDnsDns { + /*! + * \brief This function represents performing an inner product between a column of lhs + * and a column of rhs and then assigning the value to out[i]. 
+ * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_rows_l number of rows of lhs + * \param num_cols number of columns of outputs + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const int num_rows_l, + const int num_cols) { + const int irow = i / num_cols; // col id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (int k = 0; k < num_rows_l; ++k) { + const IType low = indptr_l[k]; + const IType high = indptr_l[k+1]; + if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; + int j = -1, l = low, r = high - 1; + while (l <= r) { + int m = l + (r - l) / 2; + if (col_idx_l[m] == irow) { + j = m; break; + } + if (col_idx_l[m] < irow) { + l = m + 1; + } else { + r = m - 1; + } + } + if (j >= 0) { + sum += data_l[j] * data_r[k*num_cols+icol]; + } + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrDnsDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows, const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : num_rows); + for (size_t j = seg_start; j < seg_end; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_out = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto val = data_l[k]; + const size_t offset_r = col_idx_l[k] * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrTransDnsDnsByRowBlocks { + /*! 
+ * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows_l, const size_t num_rows, + const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t j = 0; j < num_rows_l; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + const size_t offset_out = col_idx * num_cols; + const auto val = data_l[k]; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +template +void DotCsrDnsDnsImpl(const OpContext& ctx, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + if (!lhs.storage_initialized()) return; + + mshadow::Stream *s = ctx.get_stream(); + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + const TBlob data_out = *ret; + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + if (std::is_same::value) { // cpu parallelization by row blocks + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + } else { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + data_out.shape_[0], data_out.shape_[1]); + } + } else { // gpu parallelization by output elements + if (trans_lhs) { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], + data_out.shape_[1]); + }); + } else { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), rhs.shape_[1]); + }); + } + } + }); + }); + }); +} + +template +void DotCsrRspDnsImpl(const OpContext& ctx, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + CHECK_RSP_ALL_ROWS_NON_ZERO(rhs, "Dot", "rhs"); + // reuse csr dns implementation when storage_shape == shape for rhs + DotCsrDnsDnsImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); +} + +template +void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const DotParam& param = nnvm::get(attrs.parsed); + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], 
!param.transpose_a, &ret); +} + +template +void DotBackwardCsrRspDns(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const auto& rhs = inputs[2]; + CHECK_RSP_ALL_ROWS_NON_ZERO(rhs, "Dot", "rhs"); + // reuse csr dns implementation when storage_shape == shape for rhs + const DotParam& param = nnvm::get(attrs.parsed); + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); +} + inline bool DotShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -519,6 +780,68 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, return true; } +template +void DotForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; + auto lhs_stype = inputs[0].storage_type(); + auto rhs_stype = inputs[1].storage_type(); + auto out_stype = outputs[0].storage_type(); + if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && + out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else { // TODO(junwu): add fallback + LOG(FATAL) << "Not supported dot operation for lhs.storage_type = " + << inputs[0].storage_type() << ", rhs.storage_type = " << inputs[1].storage_type() + << ", out.storage_type = " << outputs[0].storage_type(); + } +} + +template +void DotBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_EQ(kNullOp, req[0]) + << "sparse dot does not support computing the gradient of the csr/lhs"; + CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; + + // TODO(junwu): check whether this CHECK is reasonable + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; + auto ograd_stype = inputs[0].storage_type(); + auto lhs_stype = inputs[1].storage_type(); + auto rhs_stype = inputs[2].storage_type(); + if (ograd_stype == kDefaultStorage // ograd dns format + && lhs_stype == kCSRStorage // csr input lhs of the op + && rhs_stype == kDefaultStorage // dns input rhs of the op + && outputs[1].storage_type() == kDefaultStorage) { // grad(rhs) dns format + // dns, csr, dns => *, dns + DotBackwardCsrDnsDns(attrs, ctx, inputs, req, outputs); + } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage && + rhs_stype == kRowSparseStorage && outputs[1].storage_type() == kDefaultStorage) { + // dns, csr, rsp => *, dns + DotBackwardCsrRspDns(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; + } +} + template void BatchDotForward_(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -786,6 +1109,96 @@ void Slice(const 
nnvm::NodeAttrs& attrs, }); } +// slice the indptr of a csr +struct SliceCsrIndPtr { + template + MSHADOW_XINLINE static void Map(int i, IType* out, const IType* in, const IType* base) { + KERNEL_ASSIGN(out[i], kWriteTo, in[i] - *base); + } +}; + +/* + * a wrapper to launch SliceCsrIndPtr kernel. + * slice [src[begin] .. src[end]) and store in dst[0, end - begin) + */ +template +void SliceCsrIndPtrImpl(const int begin, const int end, RunContext ctx, + const IType* src, IType* dst) { + using namespace mshadow; + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + int indptr_len = end - begin + 1; + Kernel::Launch(s, indptr_len, dst, src + begin, src + begin); +} + +/* + * Slice a CSR NDArray + * Only implemented for CPU + */ +template +void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, + const NDArray &in, OpReqType req, const NDArray &out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace csr; + CHECK((std::is_same::value)) << "Slice for CSR input only implemented for CPU"; + if (req == kNullOp) return; + CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported"; + CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported"; + Stream *s = ctx.get_stream(); + int begin = *param.begin[0]; + int end = *param.end[0]; + int indptr_len = end - begin + 1; + out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len)); + if (!in.storage_initialized()) { + out.set_aux_shape(kIndPtr, Shape1(0)); + return; + } + // assume idx indptr share the same type + MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIndPtr), RType, { + MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIdx), IType, { + MSHADOW_TYPE_SWITCH(in.dtype(), DType, { + auto in_indptr = in.aux_data(kIndPtr).dptr(); + auto out_indptr = out.aux_data(kIndPtr).dptr(); + SliceCsrIndPtrImpl(begin, end, ctx.run_ctx, in_indptr, out_indptr); + + // retrieve nnz (CPU implementation) + int nnz = out_indptr[indptr_len - 1]; + // copy indices and values + out.CheckAndAllocAuxData(kIdx, Shape1(nnz)); + out.CheckAndAllocData(Shape1(nnz)); + auto in_idx = in.aux_data(kIdx).dptr(); + auto out_idx = out.aux_data(kIdx).dptr(); + auto in_data = in.data().dptr(); + auto out_data = out.data().dptr(); + int offset = in_indptr[begin]; + // this is also a CPU-only implementation + memcpy(out_idx, in_idx + offset, nnz * sizeof(IType)); + memcpy(out_data, in_data + offset, nnz * sizeof(DType)); + }); + }); + }); +} + +template +void SliceEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + const SliceParam& param = nnvm::get(attrs.parsed); + auto in_stype = inputs[0].storage_type(); + CHECK_NE(in_stype, kDefaultStorage) + << "SliceEx is not expected to execute for input with default storage type"; + if (in_stype == kCSRStorage) { + SliceCsrImpl(param, ctx, inputs[0], req[0], outputs[0]); + } else { + LOG(FATAL) << "Slice not implemented for storage type" << in_stype; + } +} + inline bool SliceAssignShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index f3d69733a814..9ac998f02378 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -232,6 +232,9 @@ and ``end=(e_1, e_2, ... 
e_n)`` indices will result in an array with the shape The resulting array's *k*-th dimension contains elements from the *k*-th dimension of the input array with the open range ``[b_k, e_k)``. +For an input array of non-default storage type(e.g. `csr` or `row_sparse`), it only supports +slicing on the first dimension. + Example:: x = [[ 1., 2., 3., 4.], @@ -245,8 +248,10 @@ Example:: .set_attr_parser(ParamParser) .set_attr("FInferShape", SliceShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_slice"}) .set_attr("FCompute", Slice) +.set_attr("FComputeEx", SliceEx) .add_argument("data", "NDArray-or-Symbol", "Source input") .add_arguments(SliceParam::__FIELDS__()); @@ -370,7 +375,13 @@ NNVM_REGISTER_OP(dot) }) .set_attr("FInferShape", DotShape) .set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferStorageType", DotForwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) .set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) .add_argument("lhs", "NDArray-or-Symbol", "The first input") .add_argument("rhs", "NDArray-or-Symbol", "The second input") @@ -381,7 +392,13 @@ NNVM_REGISTER_OP(_backward_dot) .set_num_outputs(2) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", DotBackwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) .set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx) .add_arguments(DotParam::__FIELDS__()); NNVM_REGISTER_OP(batch_dot) diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu index 96c075a7d483..2e1effb9e560 100644 --- a/src/operator/tensor/matrix_op.cu +++ b/src/operator/tensor/matrix_op.cu @@ -40,10 +40,13 @@ NNVM_REGISTER_OP(_backward_slice_axis) .set_attr("FCompute", SliceAxisGrad_); NNVM_REGISTER_OP(dot) -.set_attr("FCompute", DotForward_); +.set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx); NNVM_REGISTER_OP(_backward_dot) -.set_attr("FCompute", DotBackward_); +.set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx); + NNVM_REGISTER_OP(batch_dot) .set_attr("FCompute", BatchDotForward_); diff --git a/tests/ci_build/install/ubuntu_install_python.sh b/tests/ci_build/install/ubuntu_install_python.sh index 0459bb9198c4..6ac615c7ee7f 100755 --- a/tests/ci_build/install/ubuntu_install_python.sh +++ b/tests/ci_build/install/ubuntu_install_python.sh @@ -6,5 +6,5 @@ apt-get update && apt-get install -y python-dev python3-dev # the version of the pip shipped with ubuntu may be too lower, install a recent version here cd /tmp && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python2 get-pip.py -pip2 install nose pylint numpy nose-timer requests -pip3 install nose pylint numpy nose-timer requests +pip2 install nose pylint numpy nose-timer requests scipy +pip3 install nose pylint numpy nose-timer requests scipy diff --git a/tests/cpp/include/test_ndarray_utils.h b/tests/cpp/include/test_ndarray_utils.h new file mode 100644 index 000000000000..4a99d2759c3b --- /dev/null +++ b/tests/cpp/include/test_ndarray_utils.h @@ -0,0 +1,115 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file test_utils.h + * \brief operator unit test utility functions + * \author Haibin Lin +*/ +#ifndef TESTS_CPP_INCLUDE_TEST_NDARRAY_UTILS_H_ +#define TESTS_CPP_INCLUDE_TEST_NDARRAY_UTILS_H_ + +/*#include +#include +#include +#include +#include +#include +#include +#include + +#include "../src/operator/tensor/elemwise_binary_op.h" +#include "../src/operator/tensor/elemwise_unary_op.h" +#include "../src/operator/optimizer_op-inl.h" +#include "../src/operator/tensor/init_op.h" + +using namespace mxnet; +#define TEST_DTYPE float +#define TEST_ITYPE int32_t + +void CheckDataRegion(const TBlob &src, const TBlob &dst) { + auto size = src.shape_.Size() * mshadow::mshadow_sizeof(src.type_flag_); + auto equals = memcmp(src.dptr_, dst.dptr_, size); + EXPECT_EQ(equals, 0); +} + +float RandFloat() { + float v = rand() * 1.0 / RAND_MAX; + return v; +} + +// Get an NDArray with provided indices, prepared for a RowSparse NDArray. +NDArray RspIdxND(const TShape shape, const Context ctx, const std::vector &values) { + NDArray nd(shape, ctx, false, ROW_SPARSE_IDX_TYPE); + size_t num_val = values.size(); + MSHADOW_TYPE_SWITCH(nd.dtype(), DType, { + auto tensor = nd.data().FlatTo1D(); + for (size_t i = 0; i < num_val; i++) { + tensor[i] = values[i]; + } + }); + return nd; +} + +// Get a dense NDArray with provided values. +NDArray DnsND(const TShape shape, const Context ctx, std::vector vs) { + NDArray nd(shape, ctx, false); + size_t num_val = shape.Size(); + // generate random values + while (vs.size() < num_val) { + auto v = RandFloat(); + vs.push_back(v); + } + CHECK_EQ(vs.size(), nd.shape().Size()); + MSHADOW_TYPE_SWITCH(nd.dtype(), DType, { + auto tensor = nd.data().FlatTo1D(); + for (size_t i = 0; i < num_val; i++) { + tensor[i] = vs[i]; + } + }); + return nd; +} + +// Get a RowSparse NDArray with provided indices and values +NDArray RspND(const TShape shape, const Context ctx, const std::vector idx, + std::vector vals) { + CHECK(shape.ndim() <= 2) << "High dimensional row sparse not implemented yet"; + index_t num_rows = idx.size(); + index_t num_cols = vals.size() / idx.size(); + // create index NDArray + NDArray index = RspIdxND(mshadow::Shape1(num_rows), ctx, idx); + CHECK_EQ(vals.size() % idx.size(), 0); + // create value NDArray + NDArray data = DnsND(mshadow::Shape2(num_rows, num_cols), ctx, vals); + // create result nd + NDArray nd(kRowSparseStorage, shape, ctx, false, mshadow::default_type_flag, + {}, {mshadow::Shape1(num_rows)}); + // assign values + NDArray nd_aux = nd.aux_ndarray(0); + NDArray nd_data = nd.data_ndarray(); + CopyFromTo(index, &nd_aux); + CopyFromTo(data, &nd_data); + return nd; +} + +// TODO(haibin) support other types +NDArray Convert(NDArrayStorageType type, NDArray src) { + CHECK_EQ(type, kDefaultStorage); + NDArray converted(src.shape(), src.ctx(), false); + Engine::Get()->PushSync([src, converted](RunContext ctx) { + // TODO provide type in attrs, which is empty now + OpContext op_ctx; + op_ctx.run_ctx = ctx; + if (src.storage_type() == kRowSparseStorage) { + std::vector inputs({src}), outputs({converted}); + op::CastStorageComputeEx({}, op_ctx, inputs, {}, outputs); + } else if (src.storage_type() == kDefaultStorage) { + std::vector inputs({src.data()}), outputs({converted.data()}); + op::IdentityCompute({}, op_ctx, inputs, {kWriteTo}, outputs); + } else { + LOG(FATAL) << "unsupported storage type"; + } + }, src.ctx(), {src.var()}, {converted.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + 
converted.WaitToRead(); + return converted; +}*/ +#endif // TESTS_CPP_INCLUDE_TEST_NDARRAY_UTILS_H_ diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index 719980b5d4f5..32d60cf3e4e4 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2017 by Contributors * \file batchnorm_test.cc - * \brief operator unit test utility functions + * \brief batchnorm operator unit test utility functions * \author Chris Olivier */ @@ -874,8 +874,8 @@ TEST(BATCH_NORM, TestIterAll) { kwargs.push_back({ "cudnn_off", "True" }); } for (TShape shape : shapes) { - for (int g1 = 0; g1 < 2U; ++g1) { - for (int g2 = 0; g2 < 2U; ++g2) { + for (int g1 = 0; g1 < 2; ++g1) { + for (int g2 = 0; g2 < 2; ++g2) { for (int type : v2_types) { MSHADOW_REAL_TYPE_SWITCH_EX( type, DType, AccReal, diff --git a/tests/cpp/operator/ndarray_test.cc b/tests/cpp/operator/ndarray_test.cc new file mode 100644 index 000000000000..f2ed30793881 --- /dev/null +++ b/tests/cpp/operator/ndarray_test.cc @@ -0,0 +1,6 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file ndarray_test.cc + * \brief ndarray unit test utility functions + * \author Haibin Lin +*/ diff --git a/tests/cpp/unittest.mk b/tests/cpp/unittest.mk index 808b655e9dba..ec7bb55ec983 100644 --- a/tests/cpp/unittest.mk +++ b/tests/cpp/unittest.mk @@ -47,4 +47,4 @@ testclean: -include build/tests/cpp/*.d -include build/tests/cpp/operator/*.d -include build/tests/cpp/storage/*.d --include build/tests/cpp/engine/*.d \ No newline at end of file +-include build/tests/cpp/engine/*.d diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py index ebed6c57586d..c30aaed13a7a 100644 --- a/tests/nightly/dist_sync_kvstore.py +++ b/tests/nightly/dist_sync_kvstore.py @@ -11,38 +11,89 @@ def check_diff_to_scalar(A, x): assert(np.sum(np.abs((A - x).asnumpy())) == 0), A.asnumpy() # setup -keys = [3, 5, 7] +keys = ['3', '5', '7'] +rsp_keys = ['9', '11', '13'] + rate = 2 shape = (2, 2) big_shape = (1200, 1200) # big than BIGARRAY_BOUND -kv = mx.kv.create('dist_sync') - -# init kv -kv.init(keys, [mx.nd.ones(shape)] * len(keys)) -kv.init(99, mx.nd.ones(big_shape)) -# init updater on servers -kv.set_optimizer(mx.optimizer.create('test', rate)) +def init_kv(): + kv = mx.kv.create('dist_sync') + # init kv + kv.init(keys, [mx.nd.ones(shape)] * len(keys)) + kv.init('99', mx.nd.ones(big_shape)) + my_rank = kv.rank + nworker = kv.num_workers + # init updater on servers + kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate)) + return kv, my_rank, nworker -my_rank = kv.rank -nworker = kv.num_workers +def init_kv_rsp(): + kv = mx.kv.create('dist_sync') + # init kv + kv.init(rsp_keys, [mx.nd.ones(shape)._to_rsp()] * len(rsp_keys)) + # kv.init(99, mx.nd.ones(big_shape)) + my_rank = kv.rank + nworker = kv.num_workers + # init updater on servers + kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate)) + return kv, my_rank, nworker def test_sync_push_pull(): + kv, my_rank, nworker = init_kv() nrepeat = 3 for i in range(nrepeat): - kv.push(3, mx.nd.ones(shape)*(my_rank+1)) - kv.push(99, mx.nd.ones(big_shape)*(my_rank+1)) + kv.push('3', mx.nd.ones(shape)*(my_rank+1)) + kv.push('99', mx.nd.ones(big_shape)*(my_rank+1)) num = (nworker + 1 ) * nworker * rate / 2 * nrepeat + 1 val = mx.nd.zeros(shape) - kv.pull(3, out = val) + kv.pull('3', out = val) check_diff_to_scalar(val, num) - # print val.asnumpy() val2 = mx.nd.zeros(big_shape) - kv.pull(99, out = val2) + 
kv.pull('99', out = val2) check_diff_to_scalar(val2, num) + print('done') + +def test_sync_push_pull_row_sparse(): + kv, my_rank, nworker = init_kv_rsp() + nrepeat = 2 + + v = mx.nd.zeros(shape) + my_row = my_rank % shape[0] + for col in range(shape[1]): + v[my_row][col] = my_rank + 1 + + for i in range(nrepeat): + kv.push('9', v._to_rsp()) + # kv.push(99, mx.nd.ones(big_shape)*(my_rank+1)) + + # pull a subset of rows this worker is interested in + val = v.copyto(mx.cpu())._to_rsp() + kv.pull('9', out = val) + + expected = mx.nd.zeros(shape) + # initial value + for col in range(shape[1]): + expected[my_row][col] = 1 + # apply updates from workers + for rank in range(nworker): + row = rank % shape[0] + if row != my_row: + continue + for col in range(shape[1]): + expected[my_row][col] += (rank + 1) * rate * nrepeat + #print("expect ", expected.asnumpy()) + + check_diff_to_scalar(val, expected) + # print('done') + #val2 = mx.nd.zeros(big_shape) + #kv.pull(99, out = val2) + #check_diff_to_scalar(val2, num) if __name__ == "__main__": test_sync_push_pull() + test_sync_push_pull_row_sparse() diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index 35598bc55be8..9188dd9d933f 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -112,6 +112,37 @@ def test_incomplete_infer_concat(): assert arg_shapes['b'] == (2, 5) assert arg_shapes['d'] == (2, 15) +def test_fc_infer_type(): + mx_real_t = mx.base.mx_real_t + data = mx.symbol.Variable('data') + out = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000) + + # infer type + data_type = mx_real_t + arg_types, out_types, aux_types = out.infer_type(data=data_type) + arg_type_dict = dict(zip(out.list_arguments(), arg_types)) + assert len(out_types) == 1 + assert out_types[0] == mx_real_t + true_types = { + 'fc1_bias' : mx_real_t, + 'fc1_weight' : mx_real_t } + for k, v in true_types.items(): + assert arg_type_dict[k] == v + +def check_infer_storage(v1, v2, v1_storage, v2_storage, out_chunk): + out = mx.symbol.elemwise_add(v1, v2) + arg_storage_types, out_storage_types, aux_storage_types = out.infer_storage_type(v1=v1_storage, v2=v2_storage) + assert len(out_storage_types) == 1 + assert out_storage_types[0] == out_chunk + +def test_elemwise_add_infer_storage_type(): + v1 = mx.symbol.Variable('v1') + v2 = mx.symbol.Variable('v2') + check_infer_storage(v1, v2, 'default', 'default', 'default') + check_infer_storage(v1, v2, 'default', 'row_sparse', 'default') + check_infer_storage(v1, v2, 'row_sparse', 'default', 'default') + check_infer_storage(v1, v2, 'row_sparse', 'row_sparse', 'row_sparse') + if __name__ == "__main__": test_mlp2_infer_shape() test_mlp2_infer_error() @@ -121,3 +152,4 @@ def test_incomplete_infer_concat(): test_incomplete_infer_slicechannel() test_incomplete_infer_convolution() test_incomplete_infer_concat() + test_elemwise_add_infer_storage_type() diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 5fe61b185041..4cbb4f19e40a 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -1,5 +1,6 @@ # pylint: skip-file import mxnet as mx +from mxnet.test_utils import * import numpy as np import os, gzip import pickle as pickle @@ -88,7 +89,43 @@ def test_NDArrayIter(): else: assert(labelcount[i] == 100) +''' +def test_libsvm(): + #TODO(haibin) automatic the test instead of hard coded test + cwd = os.getcwd() + data_path = os.path.join(cwd, 'data.t') + label_path = 
os.path.join(cwd, 'label.t') + with open(data_path, 'w') as fout: + fout.write('1.0 0:0.5 2:1.2\n') + fout.write('-2.0\n') + fout.write('-3.0 0:0.6 1:2.4 2:1.2\n') + fout.write('4 2:-1.2\n') + + with open(label_path, 'w') as fout: + fout.write('1.0\n') + fout.write('-2.0 0:0.125\n') + fout.write('-3.0 2:1.2\n') + fout.write('4 1:1.0 2:-1.2\n') + + data_dir = os.path.join(os.getcwd(), 'data') + f = (data_path, label_path, (3,), (3,), 3) + data_train = mx.io.LibSVMIter(data_libsvm=f[0], + label_libsvm=f[1], + data_shape=f[2], + label_shape=f[3], + batch_size=f[4]) + + first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]]) + second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]]) + i = 0 + for batch in iter(data_train): + expected = first.asnumpy() if i == 0 else second.asnumpy() + assert_almost_equal(data_train.getdata().asnumpy(), expected) + i += 1 +''' + if __name__ == "__main__": test_NDArrayIter() test_MNISTIter() test_Cifar10Rec() + # test_libsvm() diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index dd8149d4822e..bd12f95b2496 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -1,115 +1,184 @@ # pylint: skip-file import mxnet as mx import numpy as np +from mxnet.test_utils import rand_ndarray, assert_almost_equal shape = (4, 4) keys = [5, 7, 11] -def init_kv(): +str_keys = ['b', 'c', 'd'] + +def init_kv(stype='default'): """init kv """ kv = mx.kv.create() # single - kv.init(3, mx.nd.zeros(shape)) + kv.init(3, mx.nd.zeros(shape=shape, storage_type=stype)) # list - kv.init(keys, [mx.nd.zeros(shape)] * len(keys)) + kv.init(keys, [mx.nd.zeros(shape=shape, storage_type=stype)] * len(keys)) return kv +def init_kv_with_str(): + """init kv """ + kv = mx.kv.create() + # single + kv.init('a', mx.nd.zeros(shape)) + # list + kv.init(str_keys, [mx.nd.zeros(shape)] * len(keys)) + return kv def check_diff_to_scalar(A, x): """ assert A == x""" assert(np.sum(np.abs((A - x).asnumpy())) == 0) + def test_single_kv_pair(): """single key-value pair push & pull""" + def check_single_kv_pair(kv, key): + kv.push(key, mx.nd.ones(shape)) + val = mx.nd.empty(shape) + kv.pull(key, out = val) + check_diff_to_scalar(val, 1) + + check_single_kv_pair(init_kv(), 3) + check_single_kv_pair(init_kv_with_str(), 'a') - kv = init_kv() - kv.push(3, mx.nd.ones(shape)) - val = mx.nd.empty(shape) - kv.pull(3, out = val) - check_diff_to_scalar(val, 1) def test_init(): """test init""" - kv = mx.kv.create() - kv.init(3, mx.nd.ones(shape)*4) - a = mx.nd.zeros(shape) - kv.pull(3, out=a) - check_diff_to_scalar(a, 4) + def check_init(kv, key): + kv.init(key, mx.nd.ones(shape)*4) + a = mx.nd.zeros(shape) + kv.pull(key, out=a) + check_diff_to_scalar(a, 4) + + check_init(mx.kv.create(), 3) + check_init(mx.kv.create(), 'a') + def test_list_kv_pair(): """list key-value pair push & pull""" + def check_list_kv_pair(kv, key): + kv.push(key, [mx.nd.ones(shape)*4] * len(key)) + val = [mx.nd.empty(shape)] * len(key) + kv.pull(key, out = val) + for v in val: + check_diff_to_scalar(v, 4) - kv = init_kv() - - kv.push(keys, [mx.nd.ones(shape)*4] * len(keys)) - val = [mx.nd.empty(shape)] * len(keys) - kv.pull(keys, out = val) - for v in val: - check_diff_to_scalar(v, 4) + check_list_kv_pair(init_kv(), keys) + check_list_kv_pair(init_kv_with_str(), str_keys) def test_aggregator(): """aggregate value on muliple devices""" - kv = init_kv() + def check_aggregator(kv, key, key_list): + # devices + num_devs = 4 + devs = 
[mx.Context('cpu', i) for i in range(num_devs)] + + # single + vals = [mx.nd.ones(shape, d) for d in devs] + + kv.push(key, vals) + kv.pull(key, out = vals) + + for v in vals: + check_diff_to_scalar(v, num_devs) + + # list + vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(key_list) + kv.push(key_list, vals) + kv.pull(key_list, out = vals) + + for vv in vals: + for v in vv: + check_diff_to_scalar(v, num_devs * 2.0) + + check_aggregator(init_kv(), 3, keys) + check_aggregator(init_kv_with_str(), 'a', str_keys) + + +def test_sparse_aggregator(): + """aggregate sparse ndarray on muliple devices""" + + stype = 'row_sparse' + kv = init_kv(stype) # devices num_devs = 4 devs = [mx.Context('cpu', i) for i in range(num_devs)] # single - vals = [mx.nd.ones(shape, d) for d in devs] + vals = [rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)] + expected_sum = np.zeros(shape) + for v in vals: + expected_sum += v.asnumpy() kv.push(3, vals) kv.pull(3, out = vals) - + result_sum = np.zeros(shape) for v in vals: - check_diff_to_scalar(v, num_devs) + result_sum += v.asnumpy() + assert_almost_equal(result_sum, expected_sum * num_devs) # list - vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(keys) + vals = [[rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]] * len(keys) + expected_sum = np.zeros(shape) + for v in vals[0]: + expected_sum += v.asnumpy() + kv.push(keys, vals) kv.pull(keys, out = vals) - for vv in vals: + result_sum = np.zeros(shape) for v in vv: - check_diff_to_scalar(v, num_devs * 2.0) + result_sum += v.asnumpy() + assert_almost_equal(result_sum, expected_sum * num_devs) def updater(key, recv, local): """use updater: +=""" local += recv + def test_updater(dev = 'cpu'): """updater""" - kv = init_kv() - kv._set_updater(updater) + def check_updater(kv, key, key_list): + # devices + num_devs = 4 + devs = [mx.Context(dev, i) for i in range(num_devs)] - # devices - num_devs = 4 - devs = [mx.Context(dev, i) for i in range(num_devs)] + # single + vals = [mx.nd.ones(shape, d) for d in devs] - # single - vals = [mx.nd.ones(shape, d) for d in devs] + kv.push(key, vals) + kv.pull(key, out = vals) - kv.push(3, vals) - kv.pull(3, out = vals) + for v in vals: + check_diff_to_scalar(v, num_devs) - for v in vals: - check_diff_to_scalar(v, num_devs) + # list + vals = [[mx.nd.ones(shape, d) for d in devs]] * len(key_list) - # list - vals = [[mx.nd.ones(shape, d) for d in devs]] * len(keys) + num_push = 4 + for i in range(num_push): + kv.push(key_list, vals) - num_push = 4 - for i in range(num_push): - kv.push(keys, vals) + kv.pull(key_list, out = vals) + + for vv in vals: + for v in vv: + check_diff_to_scalar(v, num_devs * num_push) + + kv = init_kv() + kv._set_updater(updater) + check_updater(kv, 3, keys) + + str_kv = init_kv_with_str() + str_kv._set_updater(updater) + check_updater(str_kv, 'a', str_keys) - kv.pull(keys, out = vals) - for vv in vals: - for v in vv: - check_diff_to_scalar(v, num_devs * num_push) def test_get_type(): kvtype = 'local_allreduce_cpu' @@ -121,5 +190,6 @@ def test_get_type(): test_get_type() test_single_kv_pair() test_list_kv_pair() + test_sparse_aggregator() test_aggregator() test_updater() diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 9f3cff8e1265..dcc0f38b208a 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -1,9 +1,11 @@ -import mxnet as mx import mxnet.ndarray as nd +from mxnet.test_utils import * import numpy as np from functools 
import reduce from mxnet.module.executor_group import DataParallelExecutorGroup +import numpy.random as rnd + def test_module_dtype(): dtype = np.float16 @@ -262,7 +264,6 @@ def mean_abs(x): break assert(mon_result_counts == [2, 2, 1, 6, 6, 4]) - def test_executor_group(): def get_rnn_sym(num_layers, num_words, num_hidden, num_embed, seq_len): stack = mx.rnn.SequentialRNNCell() @@ -374,6 +375,96 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N test_shared_exec_group(exec_grp_shared=exec_group1, exec_grp_created=exec_group2, shared_arg_names=shared_arg_names, extra_args=extra_args) +def test_module_fm(): + mx.random.seed(11) + rnd.seed(11) + def fm_model(k, feature_dim): + norm = mx.initializer.Normal(sigma=0.01) + x = mx.symbol.Variable("data", storage_type='csr') + v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, storage_type='row_sparse') + + w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm) + w1 = mx.symbol.dot(x, w1_weight) + + v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1) + x_s = mx.symbol.square(data=x) + bd = 0.5 * mx.symbol.negative(data=mx.symbol.broadcast_mul(x_s, v_s)) + + w2 = mx.symbol.dot(x, v) + w2_squared = 0.5 * mx.symbol.square(data=w2) + + w_all = mx.symbol.Concat(w1, w2_squared, bd, dim=1) + model = mx.symbol.sum(data=w_all, axis=1, keepdims=True) + y = mx.symbol.Variable("out_label") + model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") + return model + + # model + ctx = default_context() + k = 5 + feature_dim = 20 + model = fm_model(k, feature_dim) + + # data iter + num_batches = 8 + batch_size = 25 + num_samples = batch_size * num_batches + import scipy.sparse as sp + # generate some random scipy csr data + csr_sp = sp.rand(num_samples, feature_dim, density=0.5, format='csr') + csr_nd = mx.sparse_nd.csr(csr_sp.data, csr_sp.indptr, csr_sp.indices, + (num_samples, feature_dim)) + label = mx.nd.ones((num_samples,1)) + # the alternative is to use LibSVMIter + train_iter = mx.io.NDArrayIter(data=csr_nd, + label={'out_label':label}, + batch_size=batch_size) + # create module + mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['out_label']) + # allocate memory by given the input data and lable shapes + mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) + # initialize parameters by uniform random numbers + mod.init_params(initializer=mx.init.Uniform(scale=.1)) + # use Sparse SGD with learning rate 0.1 to train + mod.init_optimizer(optimizer='sgd') + # use accuracy as the metric + metric = mx.metric.create('MSE') + # train 10 epoch + for epoch in range(10): + train_iter.reset() + metric.reset() + for batch in train_iter: + mod.forward(batch, is_train=True) # compute predictions + mod.update_metric(metric, batch.label) # accumulate prediction accuracy + mod.backward() # compute gradients + mod.update() # update parameters + # print('Epoch %d, Training %s' % (epoch, metric.get())) + assert(metric.get()[1] < 0.2) + +def test_module_initializer(): + def regression_model(m): + x = mx.symbol.var("data", storage_type='csr') + v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1), + storage_type='row_sparse') + model = mx.symbol.dot(lhs=x, rhs=v) + y = mx.symbol.Variable("label") + model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") + return model + + n, m = 128, 100 + model = regression_model(m) + + data = mx.nd.zeros(shape=(n, m), storage_type='csr') + label = mx.nd.zeros((n, 1)) + 
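    # data is an all-zero csr placeholder here; the assertions further down only
    # verify that init_params() materializes 'v' with row_sparse storage and a
    # non-zero Uniform(0.1) init, not that the regression output is meaningful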
iterator = mx.io.NDArrayIter(data=data, label={'label':label}, batch_size=n) + + # create module + mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['label']) + mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label) + mod.init_params() + v = mod._arg_params['v'] + assert(v.storage_type == 'row_sparse') + assert(np.sum(v.asnumpy()) != 0) if __name__ == '__main__': test_module_dtype() @@ -385,3 +476,5 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N test_module_switch_bucket() test_monitor() test_executor_group() + test_module_fm() + test_module_initializer() diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py index 8956c4edebac..3293ae2b0abc 100644 --- a/tests/python/unittest/test_multi_device_exec.py +++ b/tests/python/unittest/test_multi_device_exec.py @@ -1,4 +1,5 @@ import os +import numpy as np import mxnet as mx def test_ctx_group(): @@ -32,5 +33,35 @@ def test_ctx_group(): else: assert arr.context == group2ctx['stage2'] +def check_ctx_group_sparse(lhs_stype, rhs_stype): + with mx.AttrScope(ctx_group='stage1'): + lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + plus = mx.symbol.elemwise_add(lhs, rhs, name='plus') + + set_stage1 = set(plus.list_arguments()) + with mx.AttrScope(ctx_group='stage2'): + softmax = mx.symbol.SoftmaxOutput(data = plus, name = 'softmax') + + set_stage2 = set(softmax.list_arguments()) - set_stage1 + + group2ctx = { + 'stage1' : mx.cpu(1), + 'stage2' : mx.cpu(2) + } + texec = softmax.simple_bind(mx.cpu(0), group2ctx=group2ctx, lhs=(1,200), rhs=(1,200)) + + for arr, name in zip(texec.arg_arrays, softmax.list_arguments()): + if name in set_stage1: + assert arr.context == group2ctx['stage1'] + else: + assert arr.context == group2ctx['stage2'] + +def test_ctx_group_sparse(): + check_ctx_group_sparse('default', 'default') + check_ctx_group_sparse('default', 'row_sparse') + check_ctx_group_sparse('row_sparse', 'row_sparse') + if __name__ == '__main__': test_ctx_group() + test_ctx_group_sparse() diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index dd38bdf98606..adf93a98f26f 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -330,6 +330,7 @@ def test_dot(): assert_almost_equal(c, C.asnumpy()) + def test_reduce(): sample_num = 200 def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 924ef351dbe5..6a1f8cfd8199 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5,6 +5,7 @@ from numpy.testing import assert_allclose from mxnet.test_utils import * + def np_softmax(x, axis=-1): # fix for old numpy on Travis not supporting keepdims # x = x - np.max(x, axis=-1, keepdims=True) @@ -39,6 +40,7 @@ def check_elementwise_sum_with_shape(shape, n): for a in arr_grad: assert_almost_equal(a.asnumpy(), out_grad.asnumpy()) + def test_elementwise_sum(): np.random.seed(0) nrepeat = 2 @@ -93,6 +95,7 @@ def check_concat_with_shape(shapes, dimension, skip_second): np_grad = arr_np[i] assert_almost_equal(grad.asnumpy(), np_grad + 1) + def test_concat(): for dimension in range(4): n = 2 @@ -139,6 +142,7 @@ def test_concat(): check_concat_with_shape(shapes,dimension,True) check_concat_with_shape(shapes,dimension,False) + def 
test_slice_channel(): def check_slice_channel(data_ndim, axis, num_outputs, squeeze_axis): ins = [] @@ -202,6 +206,7 @@ def check_regression(symbol, forward, backward): npout = backward(npout, arr_label.asnumpy().reshape(npout.shape)) assert_almost_equal(npout, arr_grad.asnumpy()) + def test_regression(): check_regression(mx.symbol.LogisticRegressionOutput, lambda x: 1.0 / (1.0 + np.exp(-x)), @@ -210,6 +215,7 @@ def test_regression(): lambda x: x, lambda x, y : x - y) + def check_softmax_with_ignore_label(xpu): X = mx.symbol.Variable('X') L = mx.symbol.Variable('L') @@ -242,6 +248,7 @@ def check_softmax_with_ignore_label(xpu): assert abs(np.sum(grad1[:int(shape[0]/2)])) < 1e-5 assert_almost_equal(grad0[int(shape[0]/2):], grad1[int(shape[0]/2):]) + def check_softmax_with_shape(shape, xpu, preserve_shape=False): # bind with label X = mx.symbol.Variable('X') @@ -258,11 +265,13 @@ def check_softmax_with_shape(shape, xpu, preserve_shape=False): exec1.backward() assert_almost_equal(grad.asnumpy(), np_softmax(x.asnumpy()) - l.asnumpy(), rtol=1e-4) + def test_softmax(): check_softmax_with_shape((3, 4), default_context(), preserve_shape=False) check_softmax_with_shape((3, 4), default_context(), preserve_shape=True) check_softmax_with_shape((3, 4, 2), default_context(), preserve_shape=True) + def test_python_op(): X = mx.symbol.Variable('X') op = mx.operator.NumpyOp() @@ -277,6 +286,7 @@ def test_python_op(): exec1.backward(dy) assert_almost_equal(dy.asnumpy(), dx.asnumpy()) + def test_swapaxes(): data = mx.symbol.Variable('data') shape = (2, 3, 4) @@ -295,6 +305,7 @@ def test_swapaxes(): assert_almost_equal(out, swap_) + def test_scalarop(): data = mx.symbol.Variable('data') shape = (3, 4) @@ -325,6 +336,7 @@ def test_scalar_pow(): check_symbolic_forward(test, [data_tmp], [data_tmp ** 2]) check_symbolic_backward(test, [data_tmp], [np.ones(shape)], [2 * data_tmp]) + def test_symbol_pow(): shape = (1, 1) @@ -343,6 +355,7 @@ def test_symbol_pow(): exp_dir = data_tmp**(exp_tmp) * np.log(data_tmp) check_symbolic_backward(test, [data_tmp, exp_tmp], [np.ones(shape)], [data_dir, exp_dir]) + def test_pow_fn(): shape = (3, 4) exp = mx.symbol.Variable("exp") @@ -352,6 +365,7 @@ def test_pow_fn(): check_symbolic_forward(y, [x], [2**x]) check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x]) + def test_relu(): def frelu(x): return np.maximum(x, 0.0) @@ -367,6 +381,7 @@ def frelu_grad(x): check_symbolic_forward(y, [xa], [ya]) check_symbolic_backward(y, [xa], [np.ones(shape)], [ga]) + def test_sigmoid(): def fsigmoid(a): return np.divide(1.0, (1.0 + np.exp(-a))) @@ -379,6 +394,7 @@ def fsigmoid(a): check_symbolic_forward(y, [xa], [ya]) check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)]) + def test_binary_logic(): def _inner_test(forward_gt, logic_sym, x_shape, y_shape, test_scalar=True): x = mx.symbol.Variable("x") @@ -434,6 +450,7 @@ def _inner_test(forward_gt, logic_sym, x_shape, y_shape, test_scalar=True): logic_sym=lambda x, y: mx.sym.broadcast_not_equal(x, y), x_shape=(1, 10), y_shape=(10, 1), test_scalar=False) + def test_embedding(): in_dim = 10 out_dim = 4 @@ -460,6 +477,7 @@ def test_embedding(): exe_test.backward([grad]) assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad)) + # check ops handle duplicate input correctly. 
def test_binary_op_duplicate_input(): data = mx.symbol.Variable('data') @@ -478,6 +496,7 @@ def test_binary_op_duplicate_input(): exe_square.backward(out_grad) assert_almost_equal(arr_grad.asnumpy(), 2.0 * data_tmp) + def test_sign(): data = mx.symbol.Variable('data') shape = (3, 4) @@ -501,6 +520,7 @@ def test_sign(): exe_test.backward(out_grad) assert_almost_equal(arr_grad.asnumpy(), npout_grad) + def test_round_ceil_floor(): data = mx.symbol.Variable('data') shape = (3, 4) @@ -517,6 +537,7 @@ def test_round_ceil_floor(): npout = np.round(data_tmp) + np.ceil(data_tmp) + np.floor(data_tmp) assert_almost_equal(out, npout) + def test_rsqrt_cos_sin(): data = mx.symbol.Variable('data') shape = (3, 4) @@ -540,6 +561,7 @@ def test_rsqrt_cos_sin(): exe_test.backward(out_grad) assert_almost_equal(arr_grad.asnumpy(), npout_grad) + def test_maximum_minimum(): data1 = mx.symbol.Variable('data') data2 = mx.symbol.Variable('data') @@ -578,6 +600,7 @@ def test_maximum_minimum(): assert_almost_equal(arr_grad1.asnumpy(), npout_grad1) assert_almost_equal(arr_grad2.asnumpy(), npout_grad2) + def test_maximum_minimum_scalar(): data1 = mx.symbol.Variable('data') shape = (3, 4) @@ -608,6 +631,7 @@ def test_maximum_minimum_scalar(): assert_almost_equal(arr_grad1.asnumpy(), npout_grad1) + def test_abs(): data = mx.symbol.Variable('data') shape = (3, 4) @@ -631,6 +655,7 @@ def test_abs(): exe_test.backward(out_grad) assert_almost_equal(arr_grad.asnumpy(), npout_grad) + def check_deconvolution_forward_backward(input_shape, num_filter, kernel, stride, pad): """configure A: input --> conv --> deconv --> output. the convolution and deconvoluiton has similar parameter which ensure @@ -729,6 +754,7 @@ def check_deconvolution_gradient(input_shape, num_filter, pad): assert_almost_equal(conv_args_grad[1].asnumpy() + deconv_addto_args_grad_npy[1], deconv_addto_args_grad[1].asnumpy(), rtol=1e-3, atol=1e-2) + def check_deconvolution_target_shape(input_shape, kernel, stride, pad, adj, target_shape=None): data = mx.sym.Variable(name="data") if target_shape: @@ -742,6 +768,7 @@ def check_deconvolution_target_shape(input_shape, kernel, stride, pad, adj, targ arg_shapes, out_shapes, _ = deconv.infer_shape(data=input_shape) assert out_shapes[0] == (input_shape[0], 5, 8, 8) + def test_deconvolution(): check_deconvolution_target_shape( input_shape = (2,3,4,4), @@ -790,6 +817,7 @@ def test_deconvolution(): pad = (3,3) ) + def check_nearest_upsampling_with_shape(shapes, scale, root_scale): arr = {'arg_%d'%i: mx.random.uniform(-10.0, 10.0, shape, ctx=mx.cpu()).copyto(default_context()) for i, shape in zip(range(len(shapes)), shapes)} arr_grad = {'arg_%d'%i: mx.nd.zeros(shape) for i, shape in zip(range(len(shapes)), shapes)} @@ -802,6 +830,7 @@ def check_nearest_upsampling_with_shape(shapes, scale, root_scale): name = 'arg_%d'%k assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4) + def check_bilinear_upsampling_with_shape(shapes, scale, root_scale): arr = {'arg_%d'%i: mx.random.uniform(-10.0, 10.0, shape, ctx=mx.cpu()).copyto(default_context()) for i, shape in zip(range(len(shapes)), shapes)} arr_grad = {'arg_%d'%i: mx.nd.zeros(shape) for i, shape in zip(range(len(shapes)), shapes)} @@ -814,6 +843,7 @@ def check_bilinear_upsampling_with_shape(shapes, scale, root_scale): name = 'arg_%d'%k assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4) + def test_nearest_upsampling(): for root_scale in [1,2,3]: for scale in [1,2,3]: @@ -822,6 +852,7 @@ 
def test_nearest_upsampling(): shapes = [(1,3,base*root_scale*scale**(num_shape-1-i),base*root_scale*scale**(num_shape-1-i)) for i in range(num_shape)] check_nearest_upsampling_with_shape(shapes, scale, root_scale) + def test_batchnorm_training(): for shape in [(2, 3), (2, 3, 2, 2)]: data_tmp = np.random.normal(-0.1, 0.1, size=shape) @@ -893,6 +924,7 @@ def test_batchnorm_training(): test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis) check_numeric_gradient(test, [data_tmp, gamma, beta], [xrolling_mean, xrolling_std], numeric_eps=1e-2, rtol=0.2) + def test_convolution_grouping(): num_filter = 4 num_group = 2 @@ -923,6 +955,7 @@ def test_convolution_grouping(): for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays): np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-4) + def gen_broadcast_data(idx): # Manually set test cases binary_op_data_shape = np.array( @@ -978,15 +1011,18 @@ def gen_broadcast_data(idx): r_shape[np.where(r_axis_flags == 0)] = 1 return [np.random.random(l_shape), np.random.random(r_shape)] + def gen_broadcast_data_int(idx): d = gen_broadcast_data(idx); return [np.round(d[0]*100), np.round(d[1]*100)] + def gen_binary_data(dummy): ndim = np.random.randint(1, 6) shape = np.random.randint(1, 6, size=(ndim,)) return [np.random.random(shape), np.random.random(shape)] + def check_binary_op_forward(symbol, baseline, gen_data): sample_num = 200 for i in range(sample_num): @@ -996,6 +1032,7 @@ def check_binary_op_forward(symbol, baseline, gen_data): y.forward(is_train=True) assert_allclose(x, y.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5) + def check_binary_op_backward(symbol, baseline, gen_data): sample_num = 200 for i in range(sample_num): @@ -1022,6 +1059,7 @@ def reduce_op(shape, x): assert_allclose(x_1, y_1.asnumpy(), rtol=1e-3, atol=1e-5) assert_allclose(x_2, y_2.asnumpy(), rtol=1e-3, atol=1e-5) + def test_binary_op(): a = mx.sym.Variable('a') b = mx.sym.Variable('b') @@ -1064,6 +1102,7 @@ def test_bneq(a, b): test_bpow(a, b) test_bneq(a, b) + def test_broadcast_binary_op(): a = mx.sym.Variable('a') b = mx.sym.Variable('b') @@ -1106,6 +1145,7 @@ def test_bequal(a, b): test_bpow(a, b) test_bequal(a, b) + def test_run_convolution_dilated_impulse_response(dil=(1,1), kernel_shape=(3,3), verbose=False): # Input for spike response spike_imgs = np.zeros(shape=(1,1,33,33), dtype=np.float32) @@ -1249,6 +1289,7 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape): exe.backward(out_grads=[mx.nd.array(out_grad_npy, ctx=default_context())]) assert_allclose(exe.grad_arrays[0].asnumpy(), out_grad_npy.reshape((5, 4, 3, 7))) + def test_reduce(): sample_num = 500 def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym, nan_prob = 0): @@ -1334,6 +1375,7 @@ def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym, outgrad.reshape(keepdim_shape) * (np.equal(data, outdata.reshape(keepdim_shape)).astype(np.float)), mx.symbol.min) + def test_broadcast(): sample_num = 200 for i in range(sample_num): @@ -1365,6 +1407,7 @@ def test_broadcasting_ele(sym_bcast): test_broadcasting_ele(sym_bcast_axis) test_broadcasting_ele(sym_bcast_to) + def test_transpose(): for ndim in range(1, 6): for t in range(5): @@ -1464,6 +1507,7 @@ def test_slice_axis(): xx[idx] = x.asnumpy()[idx] assert_allclose(xx + x_grad_npy, xgrad.asnumpy(), atol=1E-5) + def test_flip(): for ndim in range(1, 6): for t in range(5): @@ -1547,18 +1591,22 @@ def dot_sym(): x = 
mx.sym.Variable('x') y = mx.sym.Variable('y') return mx.sym.dot(x, y) + def dot_sym_xT(): x = mx.sym.Variable('x') y = mx.sym.Variable('y') return mx.sym.dot(x, y, transpose_a=True) + def dot_sym_yT(): x = mx.sym.Variable('x') y = mx.sym.Variable('y') return mx.sym.dot(x, y, transpose_b=True) + def dot_sym_xT_yT(): x = mx.sym.Variable('x') y = mx.sym.Variable('y') return mx.sym.dot(x, y, transpose_a=True, transpose_b=True) + for ashape, bshape in [((3, 4), (4, 5)), ((2,3,4), (4, 5, 6))]: m1_npy = np.random.uniform(-1, 1, ashape) m2_npy = np.random.uniform(-1, 1, bshape) @@ -1567,6 +1615,7 @@ def dot_sym_xT_yT(): check_numeric_gradient(dot_sym_yT(), [m1_npy, m2_npy.T], numeric_eps=1e-1, rtol=2e-2, atol=1e-3) check_numeric_gradient(dot_sym_xT_yT(), [m1_npy.T, m2_npy.T], numeric_eps=1e-1, rtol=2e-2, atol=1e-3) + def test_batch_dot(): for batch_size in range(1, 5): for m in range(1, 5): @@ -1615,6 +1664,7 @@ def test_batch_dot(): assert_almost_equal(exe_add.grad_dict['b'].asnumpy(), bgrad_npy + b_init_grad_npy, rtol=1e-3, atol=1e-4) + def get_correlation(data1,data2,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply): img1 = mx.sym.Variable('img1') @@ -1622,6 +1672,7 @@ def get_correlation(data1,data2,kernel_size,max_displacement,stride1,stride2,pad return mx.sym.Correlation(data1=img1,data2=img2,kernel_size =kernel_size,max_displacement = max_displacement, stride1 = stride1,stride2 = stride2,pad_size= pad_size,is_multiply = is_multiply) + def correlation_forward(data1,data2,pad_size,kernel_size,stride1,stride2,max_displacement,is_multiply): # compute output's dimension @@ -1669,6 +1720,7 @@ def correlation_forward(data1,data2,pad_size,kernel_size,stride1,stride2,max_dis out /= float(kernel_size**2*data1.shape[1]) return out,tmp1,tmp2 + def correlation_backward(out_grad,tmp1,tmp2,data1,data2,pad_size,kernel_size,stride1,stride2,max_displacement,is_multiply): # compute output's dimension @@ -1718,6 +1770,7 @@ def correlation_backward(out_grad,tmp1,tmp2,data1,data2,pad_size,kernel_size,str tmp2_grad = tmp2_grad / float(kernel_size**2*data1.shape[1]) return tmp1_grad[:,:,pad_size:pad_size+data1.shape[2],pad_size:pad_size+data1.shape[3]],tmp2_grad[:,:,pad_size:pad_size+data1.shape[2],pad_size:pad_size+data1.shape[3]], + def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply): img1 = np.random.random(data_shape) @@ -1750,6 +1803,7 @@ def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2 assert_almost_equal(exe1.grad_dict['img1'].asnumpy(), grad1, rtol=1e-3, atol=1e-4) assert_almost_equal(exe1.grad_dict['img2'].asnumpy(), grad2, rtol=1e-3, atol=1e-4) + def test_correlation(): unittest_correlation((1,3,10,10), kernel_size = 1,max_displacement = 4,stride1 = 1,stride2 = 1,pad_size = 4,is_multiply = False) @@ -1791,6 +1845,7 @@ def test_support_vector_machine_l1_svm(): assert_almost_equal(grad_np, grad.asnumpy()) + def test_support_vector_machine_l2_svm(): xpu = default_context() shape = (20, 10) @@ -1838,6 +1893,7 @@ def test_roipooling(): grad_nodes={'data':'add', 'rois':'null'}, numeric_eps=1e-4, rtol=1e-1, atol=1E-4) + def check_pad_with_shape(shape, xpu, pad_width, mode): # bind with label X = mx.symbol.Variable('X') @@ -1856,6 +1912,7 @@ def check_pad_with_shape(shape, xpu, pad_width, mode): # grad check check_numeric_gradient(Y, [x.asnumpy()], numeric_eps=1e-2, rtol=1e-2) + def test_pad(): shape1 = (2, 3, 3, 5) pad1 = (0, 0, 0, 0, 1, 2, 3, 4) @@ -1868,6 +1925,7 @@ def test_pad(): check_pad_with_shape(shape1, 
default_context(), pad1, 'reflect') check_pad_with_shape(shape2, default_context(), pad2, 'reflect') + def np_instance_norm(data, weight, bias, eps): spatial_dims = data.shape[2::] num_spatial_vals = np.prod(np.array(spatial_dims)) @@ -1884,6 +1942,7 @@ def np_instance_norm(data, weight, bias, eps): biasBatch = np.reshape(np.repeat(biasBatch, num_spatial_vals), data.shape) return weightBatch * (data - mean)/np.sqrt(var + eps) + biasBatch + def check_instance_norm_with_shape(shape, xpu): # bind with label eps = 0.001 @@ -1904,12 +1963,14 @@ def check_instance_norm_with_shape(shape, xpu): check_numeric_gradient(Y, {'X':x.asnumpy(), 'G':gamma.asnumpy(), 'B':beta.asnumpy()}, numeric_eps=1e-2, rtol=1e-2, atol=1e-2) + def test_instance_normalization(): check_instance_norm_with_shape((1, 1, 1), default_context()) check_instance_norm_with_shape((2, 1, 2), default_context()) check_instance_norm_with_shape((2,4,5,6), default_context()) check_instance_norm_with_shape((3,3,2,3,2,1,1), default_context()) + def check_l2_normalization(in_shape, mode, ctx=default_context(), norm_eps=1e-10): data = mx.symbol.Variable('data') out = mx.symbol.L2Normalization(data=data, mode=mode, eps=norm_eps) @@ -1942,6 +2003,7 @@ def check_l2_normalization(in_shape, mode, ctx=default_context(), norm_eps=1e-10 # check gradient check_numeric_gradient(out, [in_data], numeric_eps=1e-3, rtol=1e-2, atol=1e-3) + def test_l2_normalization(): for mode in ['channel', 'spatial', 'instance']: for nbatch in [1, 4]: @@ -1951,6 +2013,7 @@ def test_l2_normalization(): for width in [5, 7]: check_l2_normalization((nbatch, nchannel, height, width), mode) + def sequence_mask_numpy(array, lengths, value): arrayMask = array.copy() shape = array.shape @@ -1959,6 +2022,7 @@ def sequence_mask_numpy(array, lengths, value): arrayMask[int(lengths[i]):, i] = value return arrayMask + def check_sequence_mask(shape, xpu, mask_value): # bind with label X = mx.symbol.Variable('X') @@ -1981,12 +2045,14 @@ def check_sequence_mask(shape, xpu, mask_value): check_numeric_gradient(Y, [x.asnumpy(), l.asnumpy()], grad_nodes={'X':'write'}, numeric_eps=1e-3, rtol=1e-2) + def test_sequence_mask(): shape1 = (4, 2, 2, 3) shape2 = (1, 2, 2, 3, 1, 1) check_sequence_mask(shape1, default_context(), 2.1) check_sequence_mask(shape2, default_context(), 0.1) + def mathematical_core_binary(name, forward_mxnet_call, forward_numpy_call, @@ -2031,6 +2097,7 @@ def mathematical_core_binary(name, assert_almost_equal(arr_grad1, npout_grad1) assert_almost_equal(arr_grad2, npout_grad2) + def mathematical_core(name, forward_mxnet_call, forward_numpy_call, backward_numpy_call, data_init=5., grad_init=2.): data = mx.symbol.Variable('data') shape = (3, 4) @@ -2059,6 +2126,7 @@ def mathematical_core(name, forward_mxnet_call, forward_numpy_call, backward_num # print(npout_grad) assert_almost_equal(arr_grad, npout_grad) + def test_special_functions_using_scipy(): try: from scipy import special as scipy_special @@ -2089,6 +2157,7 @@ def rounding(name, forward_mxnet_call, forward_numpy_call, data_init=5., grad_in npout = forward_numpy_call(data_tmp) assert_almost_equal(out, npout) + def test_mathematical(): # rsqrt mathematical_core("rsqrt", @@ -2175,6 +2244,7 @@ def test_mathematical(): # fix rounding("fix", lambda x: mx.sym.fix(x), lambda x: np.fix(x)) + def test_special_functions_using_scipy(): try: from scipy import special as scipy_special @@ -2190,6 +2260,7 @@ def test_special_functions_using_scipy(): mathematical_core("gammaln", lambda x: mx.sym.gammaln(x), lambda x: 
scipy_special.gammaln(x), lambda x: scipy_special.psi(x), 0.5, 0.5) + def test_clip(): data = mx.symbol.Variable('data') shape = (30, 30) @@ -2199,6 +2270,7 @@ def test_clip(): check_symbolic_backward(test, [data_tmp], [np.ones(shape)], [np.where(data_tmp < 0.6, [1], [0]) * np.where(data_tmp > -0.6, [1], [0])]) + def test_init(): def test_basic_val_init(sym_func, np_func, shape, dtype): x = sym_func(shape=shape, dtype=dtype) @@ -2333,6 +2405,7 @@ def test_blockgrad(): assert_almost_equal(exe.outputs[0].asnumpy(), a_npy) exe.backward() # No error if BlockGrad works + def test_take(): def check_output_n_grad(data_shape, idx_shape): exe = result.simple_bind(default_context(), a=data_shape, @@ -2597,6 +2670,7 @@ def bilinear_backward_numpy(out_grad, data, grid): assert_almost_equal(exe_addto.grad_dict['data'].asnumpy(), data_grad + data_initial_grid, rtol=1e-3,atol=1e-5) assert_almost_equal(exe_addto.grad_dict['grid'].asnumpy(), grid_grad + grid_initial_grid, rtol=1e-3,atol=1e-5) + def test_index2d(): for _ in range(30): n = np.random.randint(1, 100) @@ -2606,6 +2680,7 @@ def test_index2d(): r = mx.nd.batch_take(data, x) assert_almost_equal(r.asnumpy(), data.asnumpy()[np.arange(n), x.asnumpy()]) + def test_cast(): for srctype in [np.int32, np.float32, np.float16]: for dsttype in [np.float32, np.int32, np.float16]: @@ -3072,6 +3147,7 @@ def check_ctc_loss(acts, labels, loss_truth): # test grad check_numeric_gradient(ctc, [acts, labels], grad_nodes=['input'], rtol=0.05, atol=1e-3) + def test_ctc_loss(): # Test 1: check that batches are same + check against Torch WarpCTC acts = np.array([ diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 11ca7bed1743..f87a8c7cfc7f 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -30,24 +30,41 @@ def test_lr_wd_mult(): assert not mx.test_utils.almost_equal(args1['fc2_weight'], args2['fc2_weight'], 1e-1) -def compare_optimizer(opt1, opt2, shape): - w1 = mx.random.uniform(shape=shape, ctx=default_context()) - g1 = mx.random.uniform(shape=shape, ctx=default_context()) - - w2 = w1.copyto(default_context()) - g2 = g1.copyto(default_context()) +def compare_optimizer(opt1, opt2, shape, w_stype='default', g_stype='default'): + if w_stype == 'default': + w2 = mx.random.uniform(shape=shape, ctx=default_context()) + w1 = w2.copyto(default_context()) + elif w_stype == 'row_sparse': + w2 = rand_ndarray(shape, w_stype, density=1) + w1 = w2.copyto(default_context()).todense() + else: + raise Exception("type not supported yet") + if g_stype == 'default': + g2 = mx.random.uniform(shape=shape, ctx=default_context()) + g1 = g2.copyto(default_context()) + elif g_stype == 'row_sparse': + g2 = rand_ndarray(shape, g_stype) + g1 = g2.copyto(default_context()).todense() + else: + raise Exception("type not supported yet") state1 = opt1.create_state(0, w1) state2 = opt2.create_state(0, w2) if state1 is not None and state2 is not None: - for s1, s2, in zip(state1, state2): - assert(same(s1.asnumpy(), s2.asnumpy())) + if isinstance(state1, tuple): + for s1, s2, in zip(state1, state2): + assert(same(s1.asnumpy(), s2.asnumpy())) + else: + assert_almost_equal(state1.asnumpy(), state2.asnumpy()) opt1.update(0, w1, g1, state1) opt2.update(0, w2, g2, state2) if state1 is not None and state2 is not None: - for s1, s2, in zip(state1, state2): - assert_almost_equal(s1.asnumpy(), s2.asnumpy(), rtol=1e-4, atol=1e-5) + if isinstance(state1, tuple): + for s1, s2, in zip(state1, state2): + 
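                # tuple-valued optimizer states (e.g. SGD with momentum) are compared
                # entry by entry here; a single-NDArray state is handled by the else
                # branch below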
assert_almost_equal(s1.asnumpy(), s2.asnumpy(), rtol=1e-4, atol=1e-5) + else: + assert_almost_equal(state1.asnumpy(), state2.asnumpy()) assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=1e-4, atol=1e-5) # SGD @@ -130,6 +147,98 @@ def test_sgd(): for kwarg in kwargs: compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape) +class PySparseSGD(mx.optimizer.Optimizer): + """python reference implemenation of sgd""" + def __init__(self, learning_rate=0.01, momentum=0.0, **kwargs): + super(PySparseSGD, self).__init__(learning_rate=learning_rate, **kwargs) + self.momentum = momentum + + def create_state(self, index, weight): + """Create additional optimizer state: momentum + + Parameters + ---------- + weight : NDArray + The weight data + + """ + if self.momentum == 0.0: + return None + else: + return mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) + + def update(self, index, weight, grad, state): + """Update the parameters. + + Parameters + ---------- + index : int + An unique integer key used to index the parameters + + weight : NDArray + weight ndarray + + grad : NDArray + grad ndarray + + state : NDArray or other objects returned by init_state + The auxiliary state used in optimization. + """ + lr = self._get_lr(index) + wd = self._get_wd(index) + self._update_count(index) + num_rows = weight.shape[0] + if self.momentum == 0.0: + # Update on a per row basis, skip all-zero rows + for row in range(num_rows): + grad_row = grad[row].asnumpy() + all_zeros = mx.test_utils.almost_equal(grad_row, np.zeros_like(grad_row)) + if all_zeros: + continue + if self.clip_gradient is not None: + weight[row] = ((1 - lr*wd)*weight[row] - + lr*mx.nd.clip(grad[row]*self.rescale_grad, + -self.clip_gradient, self.clip_gradient)) + else: + weight[row] = (1 - lr*wd)*weight[row] - lr*self.rescale_grad*grad[row] + else: + mom = state + for row in range(num_rows): + grad_row = grad[row].asnumpy() + all_zeros = mx.test_utils.almost_equal(grad_row, np.zeros_like(grad_row)) + if all_zeros: + continue + if self.clip_gradient is not None: + mom[row] = (self.momentum*mom[row] - lr*wd*weight[row] - + lr*mx.nd.clip(grad[row]*self.rescale_grad, -self.clip_gradient, self.clip_gradient)) + weight[row] += mom[row] + else: + mom[row] = self.momentum*mom[row] - lr*wd*weight[row] - lr*self.rescale_grad*grad[row] + weight[row] += mom[row] + +def test_sparse_sgd(): + mx.random.seed(0) + opt1 = PySparseSGD + opt2 = mx.optimizer.SGD + shape = (3, 4) + kwargs = [{}, + {'momentum': 0.9}, + {'clip_gradient': 0.5}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14}, + {'rescale_grad': 0.8}, + {'clip_gradient': 0.5, 'wd': 0.07}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03}, + {'rescale_grad': 0.8, 'wd': 0.05}, + {'clip_gradient': 0.5, 'momentum': 0.9}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'momentum': 0.9}, + {'rescale_grad': 0.8, 'momentum': 0.9}, + {'clip_gradient': 0.5, 'wd': 0.07, 'momentum': 0.9}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03, 'momentum': 0.9}, + {'rescale_grad': 0.8, 'wd': 0.05, 'momentum': 0.9}] + for kwarg in kwargs: + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='row_sparse', g_stype='row_sparse') + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='row_sparse', g_stype='default') + # ADAM class PyAdam(mx.optimizer.Optimizer): @@ -354,3 +463,4 @@ def test_rms(): test_adam() test_rms() test_sgd() + test_sparse_sgd() diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py new file mode 100644 index 
000000000000..a09857b95efe --- /dev/null +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -0,0 +1,399 @@ +import pickle as pkl + +from mxnet.ndarray import NDArray +from mxnet.test_utils import * +from numpy.testing import assert_allclose +import numpy.random as rnd + +from mxnet.sparse_ndarray import RowSparseNDArray, CSRNDArray, _ndarray_cls + + +def assert_fcompex(f, *args, **kwargs): + prev_val = mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", "0", "1") + f(*args, **kwargs) + mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", prev_val) + + +def sparse_nd_ones(shape, stype): + return mx.nd.cast_storage(mx.nd.ones(shape), storage_type=stype) + + +def check_sparse_nd_elemwise_binary(shapes, storage_types, f, g): + # generate inputs + nds = [] + for i, storage_type in enumerate(storage_types): + if storage_type == 'row_sparse': + nd, _ = rand_sparse_ndarray(shapes[i], storage_type) + elif storage_type == 'default': + nd = mx.nd.array(random_arrays(shapes[i]), dtype = np.float32) + else: + assert(False) + nds.append(nd) + # check result + test = f(nds[0], nds[1]) + assert_almost_equal(test.asnumpy(), g(nds[0].asnumpy(), nds[1].asnumpy())) + + +def test_sparse_nd_elemwise_add(): + num_repeats = 10 + g = lambda x,y: x + y + op = mx.nd.elemwise_add + for i in range(num_repeats): + shape = [rand_shape_2d()] * 2 + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['default'] * 2, op, g) + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['default', 'row_sparse'], op, g) + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['row_sparse', 'row_sparse'], op, g) + + +# Test a operator which doesn't implement FComputeEx +def test_sparse_nd_elementwise_fallback(): + num_repeats = 10 + g = lambda x,y: x + y + op = mx.nd.add_n + for i in range(num_repeats): + shape = [rand_shape_2d()] * 2 + check_sparse_nd_elemwise_binary(shape, ['default'] * 2, op, g) + check_sparse_nd_elemwise_binary(shape, ['default', 'row_sparse'], op, g) + check_sparse_nd_elemwise_binary(shape, ['row_sparse', 'row_sparse'], op, g) + + +def test_sparse_nd_zeros(): + def check_sparse_nd_zeros(stype, shape): + zero = mx.nd.zeros(shape) + sparse_zero = mx.nd.zeros(shape=shape, storage_type=stype) + assert_almost_equal(sparse_zero.asnumpy(), zero.asnumpy()) + + shape = rand_shape_2d() + check_sparse_nd_zeros('row_sparse', shape) + check_sparse_nd_zeros('csr', shape) + check_sparse_nd_zeros('default', shape) + + +def test_sparse_nd_copy(): + def check_sparse_nd_copy(from_stype, to_stype): + shape = rand_shape_2d() + from_nd = rand_ndarray(shape, from_stype) + # copy to ctx + to_ctx = from_nd.copyto(default_context()) + # copy to stype + to_nd = rand_ndarray(shape, to_stype) + to_nd = from_nd.copyto(to_nd) + assert np.sum(np.abs(from_nd.asnumpy() != to_ctx.asnumpy())) == 0.0 + assert np.sum(np.abs(from_nd.asnumpy() != to_nd.asnumpy())) == 0.0 + + check_sparse_nd_copy('row_sparse', 'row_sparse') + check_sparse_nd_copy('row_sparse', 'default') + check_sparse_nd_copy('default', 'row_sparse') + check_sparse_nd_copy('default', 'csr') + + +def check_sparse_nd_prop_rsp(): + storage_type = 'row_sparse' + shape = rand_shape_2d() + nd, (v, idx) = rand_sparse_ndarray(shape, storage_type) + assert(nd._num_aux == 1) + assert(nd.indices.dtype == np.int64) + assert(nd.storage_type == 'row_sparse') + assert_almost_equal(nd.indices.asnumpy(), idx) + + +def test_sparse_nd_basic(): + def check_rsp_creation(values, indices, shape): + rsp = mx.sparse_nd.row_sparse(values, indices, shape) + dns = mx.nd.zeros(shape) + 
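        # row_sparse layout: 'values' holds only the non-zero rows and 'indices'
        # holds their row ids, so indices [1, 3] place values[0] in dense row 1
        # and values[1] in dense row 3, which is what the lines below construct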
dns[1] = mx.nd.array(values[0]) + dns[3] = mx.nd.array(values[1]) + #assert_almost_equal(rsp.asnumpy(), dns.asnumpy()) + print('before', indices) + print('mx', mx.nd.array(indices, dtype='int64')[1].asnumpy()) + indices_np = mx.nd.array(indices, dtype='int64').asnumpy() + print('after', indices_np) + assert_almost_equal(rsp.indices.asnumpy(), indices_np) + + def check_csr_creation(shape): + csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') + assert_almost_equal(csr.indptr.asnumpy(), indptr) + assert_almost_equal(csr.indices.asnumpy(), indices) + assert_almost_equal(csr.values.asnumpy(), values) + + shape = (4,2) + values = np.random.rand(2,2) + indices = np.array([1,3], dtype='int64') + check_rsp_creation(values, indices, shape) + + values = mx.nd.array(np.random.rand(2,2)) + indices = mx.nd.array([1,3], dtype='int64') + check_rsp_creation(values, indices, shape) + + values = [[0.1, 0.2], [0.3, 0.4]] + indices = [1,3] + check_rsp_creation(values, indices, shape) + + check_csr_creation(shape) + check_sparse_nd_prop_rsp() + + +def test_sparse_nd_setitem(): + def check_sparse_nd_setitem(storage_type, shape, dst): + x = mx.nd.zeros(shape=shape, storage_type=storage_type) + x[:] = dst + dst_nd = mx.nd.array(dst) if isinstance(dst, (np.ndarray, np.generic)) else dst + assert same(x.asnumpy(), dst_nd.asnumpy()) + + shape = rand_shape_2d() + for stype in ['row_sparse', 'csr']: + # ndarray assignment + check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, 'default')) + check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, stype)) + # numpy assignment + check_sparse_nd_setitem(stype, shape, np.ones(shape)) + + +def test_sparse_nd_slice(): + def check_sparse_nd_csr_slice(shape): + storage_type = 'csr' + A, _ = rand_sparse_ndarray(shape, storage_type) + A2 = A.asnumpy() + start = rnd.randint(0, shape[0] - 1) + end = rnd.randint(start + 1, shape[0]) + assert same(A[start:end].asnumpy(), A2[start:end]) + assert same(A[start:].asnumpy(), A2[start:]) + assert same(A[:end].asnumpy(), A2[:end]) + + shape = (rnd.randint(2, 10), rnd.randint(1, 10)) + check_sparse_nd_csr_slice(shape) + + +def test_sparse_nd_equal(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, storage_type=stype) + y = sparse_nd_ones(shape, stype) + z = x == y + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 == x + assert (z.asnumpy() == np.ones(shape)).all() + + +def test_sparse_nd_not_equal(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, storage_type=stype) + y = sparse_nd_ones(shape, stype) + z = x != y + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 != x + assert (z.asnumpy() == np.zeros(shape)).all() + + +def test_sparse_nd_greater(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, storage_type=stype) + y = sparse_nd_ones(shape, stype) + z = x > y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y > 0 + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 > y + assert (z.asnumpy() == np.zeros(shape)).all() + + +def test_sparse_nd_greater_equal(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, storage_type=stype) + y = sparse_nd_ones(shape, stype) + z = x >= y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y >= 0 + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 >= y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y >= 1 + assert (z.asnumpy() == np.ones(shape)).all() + + +def 
test_sparse_nd_lesser(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, storage_type=stype) + y = sparse_nd_ones(shape, stype) + z = y < x + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 < y + assert (z.asnumpy() == np.ones(shape)).all() + z = y < 0 + assert (z.asnumpy() == np.zeros(shape)).all() + + +def test_sparse_nd_lesser_equal(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, storage_type=stype) + y = sparse_nd_ones(shape, stype) + z = y <= x + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 <= y + assert (z.asnumpy() == np.ones(shape)).all() + z = y <= 0 + assert (z.asnumpy() == np.zeros(shape)).all() + z = 1 <= y + assert (z.asnumpy() == np.ones(shape)).all() + + +def test_sparse_nd_binary(): + N = 100 + def check_binary(fn): + for _ in range(N): + ndim = 2 + oshape = np.random.randint(1, 6, size=(ndim,)) + bdim = 2 + lshape = list(oshape) + rshape = list(oshape[ndim-bdim:]) + for i in range(bdim): + sep = np.random.uniform(0, 1) + if sep < 0.33: + lshape[ndim-i-1] = 1 + elif sep < 0.66: + rshape[bdim-i-1] = 1 + lhs = np.random.uniform(0, 1, size=lshape) + rhs = np.random.uniform(0, 1, size=rshape) + lhs_nd_csr = mx.nd.array(lhs)._to_csr() + rhs_nd_csr = mx.nd.array(rhs)._to_csr() + lhs_nd_rsp = mx.nd.array(lhs)._to_rsp() + rhs_nd_rsp = mx.nd.array(rhs)._to_rsp() + for lhs_nd, rhs_nd in [(lhs_nd_csr, rhs_nd_csr), (lhs_nd_rsp, rhs_nd_rsp)]: + assert_allclose(fn(lhs, rhs), + fn(lhs_nd, rhs_nd).asnumpy(), + rtol=1e-4, atol=1e-4) + + check_binary(lambda x, y: x + y) + check_binary(lambda x, y: x - y) + check_binary(lambda x, y: x * y) + check_binary(lambda x, y: x / y) + check_binary(lambda x, y: x ** y) + check_binary(lambda x, y: x > y) + check_binary(lambda x, y: x < y) + check_binary(lambda x, y: x >= y) + check_binary(lambda x, y: x <= y) + check_binary(lambda x, y: x == y) + + +def test_sparse_nd_binary_rop(): + N = 100 + def check(fn): + for _ in range(N): + ndim = 2 + shape = np.random.randint(1, 6, size=(ndim,)) + npy_nd = np.random.normal(0, 1, size=shape) + csr_nd = mx.nd.array(npy_nd)._to_csr() + rsp_nd = mx.nd.array(npy_nd)._to_rsp() + for sparse_nd in [csr_nd, rsp_nd]: + assert_allclose( + fn(npy_nd), + fn(sparse_nd).asnumpy(), + rtol=1e-4, + atol=1e-4 + ) + check(lambda x: 1 + x) + check(lambda x: 1 - x) + check(lambda x: 1 * x) + check(lambda x: 1 / x) + check(lambda x: 2 ** x) + check(lambda x: 1 > x) + check(lambda x: 0.5 > x) + check(lambda x: 0.5 < x) + check(lambda x: 0.5 >= x) + check(lambda x: 0.5 <= x) + check(lambda x: 0.5 == x) + + +def test_sparse_nd_negate(): + npy = np.random.uniform(-10, 10, rand_shape_2d()) + arr_csr = mx.nd.array(npy)._to_csr() + arr_rsp = mx.nd.array(npy)._to_rsp() + for arr in [arr_csr, arr_rsp]: + assert_almost_equal(npy, arr.asnumpy()) + assert_almost_equal(-npy, (-arr).asnumpy()) + + # a final check to make sure the negation (-) is not implemented + # as inplace operation, so the contents of arr does not change after + # we compute (-arr) + assert_almost_equal(npy, arr.asnumpy()) + + +def test_sparse_nd_output_fallback(): + shape = (10, 10) + out = mx.nd.zeros(shape=shape, storage_type='row_sparse') + mx.nd.random_normal(shape=shape, out=out) + assert(np.sum(out.asnumpy()) != 0) + + +def test_sparse_nd_astype(): + stypes = ['row_sparse', 'csr'] + for stype in stypes: + x = mx.nd.zeros(shape=rand_shape_2d(), storage_type=stype, dtype='float32') + y = x.astype('int32') + assert(y.dtype == np.int32), y.dtype + + +def 
test_sparse_ndarray_pickle(): + np.random.seed(0) + repeat = 10 + dim0 = 40 + dim1 = 40 + stypes = ['row_sparse', 'csr'] + densities = [0, 0.01, 0.1, 0.2, 0.5] + stype_dict = {'row_sparse': RowSparseNDArray, 'csr': CSRNDArray} + for _ in range(repeat): + shape = rand_shape_2d(dim0, dim1) + for stype in stypes: + for density in densities: + a, _ = rand_sparse_ndarray(shape, stype, density) + assert isinstance(a, stype_dict[stype]) + data = pkl.dumps(a) + b = pkl.loads(data) + assert isinstance(b, stype_dict[stype]) + assert same(a.asnumpy(), b.asnumpy()) + + +def test_sparse_ndarray_save_load(): + np.random.seed(0) + repeat = 1 + stypes = ['default', 'row_sparse', 'csr'] + stype_dict = {'default': NDArray, 'row_sparse': RowSparseNDArray, 'csr': CSRNDArray} + num_data = 20 + densities = [0, 0.01, 0.1, 0.2, 0.5] + fname = 'tmp_list.bin' + for _ in range(repeat): + data_list1 = [] + for i in range(num_data): + stype = stypes[np.random.randint(0, len(stypes))] + shape = rand_shape_2d(dim0=40, dim1=40) + density = densities[np.random.randint(0, len(densities))] + data_list1.append(rand_ndarray(shape, stype, density)) + assert isinstance(data_list1[-1], stype_dict[stype]) + mx.nd.save(fname, data_list1) + + data_list2 = mx.nd.load(fname) + assert len(data_list1) == len(data_list2) + for x, y in zip(data_list1, data_list2): + assert same(x.asnumpy(), y.asnumpy()) + + data_map1 = {'ndarray xx %s' % i: x for i, x in enumerate(data_list1)} + mx.nd.save(fname, data_map1) + data_map2 = mx.nd.load(fname) + assert len(data_map1) == len(data_map2) + for k, x in data_map1.items(): + y = data_map2[k] + assert same(x.asnumpy(), y.asnumpy()) + os.remove(fname) + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py new file mode 100644 index 000000000000..55667225fd35 --- /dev/null +++ b/tests/python/unittest/test_sparse_operator.py @@ -0,0 +1,205 @@ +from mxnet.test_utils import * + + +def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_grad_stype=None): + lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + if lhs_grad_stype is not None: + lhs._set_attr(grad_stype_hint=str(lhs_grad_stype)) + if rhs_grad_stype is not None: + rhs._set_attr(grad_stype_hint=str(rhs_grad_stype)) + + lhs_nd = rand_ndarray(shape, lhs_stype) + rhs_nd = rand_ndarray(shape, rhs_stype) + lhs_np = lhs_nd.asnumpy() + rhs_np = rhs_nd.asnumpy() + + out_np = lhs_np + rhs_np + test = mx.symbol.elemwise_add(lhs, rhs) + location = {'lhs': lhs_nd, 'rhs': rhs_nd} + check_symbolic_forward(test, location, [out_np]) + check_numeric_gradient(test, location) + check_symbolic_backward(test, location, [out_np], [out_np, out_np]) + + +def test_elemwise_add_ex(): + shape = rand_shape_2d() + check_elemwise_add_ex('default', 'default', shape) + check_elemwise_add_ex('default', 'row_sparse', shape) + check_elemwise_add_ex('row_sparse', 'default', shape) + check_elemwise_add_ex('row_sparse', 'row_sparse', shape, + lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse') + + +# TODO(haibin) randomize this test +def test_elemwise_add_ex_multiple_stages(): + # prep data + shape = (4, 2) + ds_np = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + sp_np1 = np.array([[5, 10], [0, 0], [0, 0], [0, 0]]) + sp_np2 = np.array([[0, 0], [5, 10], [0, 0], [0, 0]]) + + val1 = mx.nd.array([[5, 10]]); + val2 = mx.nd.array([[5, 10]]); + idx1 = mx.nd.array([0], dtype=np.int64); 
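    # each row_sparse operand carries a single non-zero row: val1 at row idx1 = [0]
    # and val2 at row idx2 = [1], mirroring the dense references sp_np1 / sp_np2 above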
+ idx2 = mx.nd.array([1], dtype=np.int64); + sp_nd1 = mx.sparse_nd.row_sparse(val1, idx1, shape) + sp_nd2 = mx.sparse_nd.row_sparse(val2, idx2, shape) + ds_nd = mx.nd.array(ds_np) + + # sparse + sparse = sparse + sp_data1 = mx.symbol.Variable('sp_data1', storage_type='row_sparse') + sp_data2 = mx.symbol.Variable('sp_data2', storage_type='row_sparse') + ds_data = mx.symbol.Variable('ds_data') + plus = mx.symbol.elemwise_add(sp_data1, sp_data2, name='plus') + # sparse + dense = dense + test = mx.symbol.elemwise_add(plus, ds_data) + check_symbolic_forward(test, {'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, [sp_np1 + sp_np2 + ds_np]) + + arr_grads = [mx.nd.zeros(shape) for i in range(3)] + exec_test = test.bind(default_context(), args={'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, args_grad=arr_grads) + exec_test.forward(is_train=True) + assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np) + exec_test.backward(out_grads=exec_test.outputs) + assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy()) + +# TODO(haibin) also add test for backward pass. +def test_cast_storage_ex(): + def test_rsp_to_dns(shape): + rsp, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse') + dns_out = mx.nd.cast_storage(rsp, storage_type='default') + dns_expected = np.zeros(shape, dtype=default_dtype()) + if row_idx is not None: + for k, v in enumerate(row_idx): + dns_expected[v, :] = data[k] + assert same(dns_out.asnumpy(), dns_expected) + + def test_dns_to_rsp(shape): + dns_in = rand_ndarray(shape, 'default') + rsp_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), storage_type='row_sparse') + ret = mx.nd.cast_storage(rsp_out, storage_type='default') + assert same(ret.asnumpy(), dns_in.asnumpy()) + + def test_csr_to_dns(shape): + csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') + mx_dns = csr.todense() + np_dns = sp.csr_matrix((values, indices, indptr), shape).todense() + assert_almost_equal(mx_dns.asnumpy(), np_dns) + + def test_dns_to_csr(dns_in): + dns_in = np.array(dns_in) + csr_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), storage_type='csr') + ret = mx.nd.cast_storage(csr_out, storage_type='default') + assert same(ret.asnumpy(), dns_in) + + shape = rand_shape_2d() + test_rsp_to_dns(shape) + test_dns_to_rsp(shape) + test_csr_to_dns((4, 4)) + test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]]) + +def test_sparse_dot(): + def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs): + lhs_dns = rand_ndarray(lhs_shape, 'default') + lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr') + rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=1) + rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense() + out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs) + assert out.storage_type == 'default' + out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs) + out_np = out_expected.asnumpy() + backward_trans = not trans_lhs + rhs_backward_grad = mx.nd.dot(lhs_dns, out_expected, transpose_a=backward_trans).asnumpy() + assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) + + # test symbolic forward + lhs = mx.symbol.Variable('lhs', storage_type='csr') + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + test = mx.symbol.dot(lhs, rhs, transpose_a=trans_lhs) + location = {'lhs': lhs_nd, 'rhs': rhs_nd} + expected = {'rhs': rhs_backward_grad} + check_symbolic_forward(test, location, [out_np], rtol=1e-3, atol=1e-4) + # test 
symbolic backward + check_symbolic_backward(test, location, [out_np], expected, + grad_req={'lhs': 'null', 'rhs': 'write'}, + rtol=1e-3, atol=1e-4) + + lhs_shape = rand_shape_2d() + test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False) + test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True) + test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False) + test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True) + +def test_sparse_embedding(): + in_dim = 10 + out_dim = 4 + batch = 24 + + data = mx.sym.Variable("data", storage_type='csr') + embed = mx.sym.SparseEmbedding(data=data, input_dim=in_dim, output_dim=out_dim, name="embed") + exe_test = embed.simple_bind(default_context(), grad_req={'data': 'null', 'embed_weight': 'write'}, + data=(batch, in_dim)) + + arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays)) + grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays)) + np_data = np.random.randint(low=0, high=in_dim, size=batch) + np_weight = np.random.uniform(-0.01, 0.01, arg_map["embed_weight"].shape) + np_onehot = np.zeros((batch, in_dim)) + np_onehot[np.arange(batch), np_data] = 1.0 + nd_onehot = mx.nd.array(np_onehot)._to_csr() + # forward + arg_map["data"][:] = nd_onehot + arg_map["embed_weight"][:] = np_weight + exe_test.forward(is_train=True) + assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, np_weight)) + # backward + np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape) + grad = mx.nd.zeros(np_grad.shape) + grad[:] = np_grad + exe_test.backward([grad]) + assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad), atol=1e-5) + + +def test_sparse_slice(): + def check_csr_slice(shape, slice_input): + storage_type = 'csr' + A, _ = rand_sparse_ndarray(shape, storage_type) + B = A._slice(1, shape[0] - 1) if slice_input else A + np = B.asnumpy() + begin = rnd.randint(0, B.shape[0] - 1) + end = rnd.randint(begin + 1, B.shape[0]) + nd_slice = mx.nd.crop(B, begin=begin, end=end) + assert same(nd_slice.asnumpy(), np[begin:end]), (nd_slice.asnumpy(), np[begin:end]) + + shape = (rnd.randint(7, 15), rnd.randint(1, 10)) + check_csr_slice(shape, True) + check_csr_slice(shape, False) + + +def test_sparse_retain(): + for _ in range(10): + shape = rand_shape_2d() + num_rows = shape[0] + rsp, _ = rand_sparse_ndarray(shape=shape, storage_type='row_sparse', density=0.5) + length = np.random.randint(1, num_rows + 1) + idx = random_sample(list(range(0, num_rows)), length) + idx.sort() + dns = rsp.asnumpy() + tensor_retained_expected = np.zeros(shape) + for i in idx: + tensor_retained_expected[i][:] = dns[i] + indices = mx.nd.array(idx) + rsp_retained = mx.nd.sparse_retain(rsp, indices=indices) + assert same(tensor_retained_expected, rsp_retained.asnumpy()) + + # check numeric gradient + data = mx.symbol.Variable('data') + idx = mx.symbol.Variable('indices') + sym = mx.sym.sparse_retain(data=data, indices=idx) + check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'], grad_stype_dict={'data': 'row_sparse'}) + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index cff4196b6043..d0ee09312cd4 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -109,11 +109,11 @@ if [ ${TASK} == "python_test" ]; then python -m nose tests/python/doctest || exit -1 python3 -m nose tests/python/doctest || exit -1 else - nosetests tests/python/unittest || exit -1 
-        nosetests3 tests/python/unittest || exit -1
-        nosetests3 tests/python/train || exit -1
-        nosetests tests/python/doctest || exit -1
-        nosetests3 tests/python/doctest || exit -1
+        nosetests -v tests/python/unittest || exit -1
+        nosetests3 -v tests/python/unittest || exit -1
+        nosetests3 -v tests/python/train || exit -1
+        nosetests -v tests/python/doctest || exit -1
+        nosetests3 -v tests/python/doctest || exit -1
     fi
     exit 0
 fi
diff --git a/tests/travis/setup.sh b/tests/travis/setup.sh
index ec071009bda5..7c9d137b8269 100755
--- a/tests/travis/setup.sh
+++ b/tests/travis/setup.sh
@@ -15,8 +15,8 @@ if [ ${TRAVIS_OS_NAME} == "osx" ]; then
     brew install ImageMagick
     brew install swig
     if [ ${TASK} == "python_test" ]; then
-        python -m pip install --user nose numpy cython
-        python3 -m pip install --user nose numpy cython
+        python -m pip install --user nose numpy cython scipy
+        python3 -m pip install --user nose numpy cython scipy
     fi
 fi
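
For orientation, the short script below strings together the sparse NDArray surface that the new unit tests above exercise: casting between default and csr storage, building a row_sparse array from values plus row indices, and a csr-by-dense dot product. It is a minimal sketch that assumes the API spellings used in this branch (storage_type=..., mx.nd.cast_storage, mx.sparse_nd.row_sparse, .storage_type); other MXNet versions spell parts of this API differently, so treat it as illustration rather than reference.

# Sketch only: assumes the sparse API exactly as spelled in this branch's tests.
import numpy as np
import mxnet as mx

shape = (4, 3)

# default (dense) -> csr -> default round trip via cast_storage
dense = mx.nd.array(np.array([[0, 1, 0], [0, 0, 0], [2, 0, 3], [0, 4, 0]], dtype=np.float32))
csr = mx.nd.cast_storage(dense, storage_type='csr')
back = mx.nd.cast_storage(csr, storage_type='default')
assert (back.asnumpy() == dense.asnumpy()).all()

# a row_sparse array is (values, indices): only rows 0 and 2 are stored
values = mx.nd.array([[1., 2., 3.], [4., 5., 6.]])
indices = mx.nd.array([0, 2], dtype='int64')
rsp = mx.sparse_nd.row_sparse(values, indices, shape)
print(rsp.storage_type, rsp.indices.asnumpy())

# csr lhs times dense rhs yields a default-storage result
rhs = mx.nd.ones((3, 2))
out = mx.nd.dot(csr, rhs)
print(out.storage_type, out.asnumpy())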