From 57da011390a3c8db47908abe0a1728bd6a7c94d2 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Wed, 2 Aug 2017 23:16:25 +0000 Subject: [PATCH] code changes according to review comments remove executor debug. add doc to optimizer update sparse sgd test add dtype option to rand_sparse_ndarray --- include/mxnet/c_api.h | 24 ++++++++ include/mxnet/ndarray.h | 1 - python/mxnet/_ctypes/ndarray.py | 15 +++-- python/mxnet/kvstore.py | 3 +- python/mxnet/ndarray/__init__.py | 2 +- python/mxnet/ndarray/sparse_ndarray.py | 56 ++++++++++++------- python/mxnet/optimizer.py | 4 +- python/mxnet/test_utils.py | 16 +++--- src/c_api/c_api_ndarray.cc | 18 ++++++ src/common/utils.h | 2 +- src/executor/graph_executor.cc | 6 -- src/operator/elemwise_op_common.h | 15 ----- src/operator/optimizer_op-inl.h | 5 +- src/operator/tensor/elemwise_unary_op.h | 15 +++++ tests/python/unittest/test_module.py | 4 +- tests/python/unittest/test_optimizer.py | 44 ++++++++------- tests/python/unittest/test_sparse_ndarray.py | 4 +- tests/python/unittest/test_sparse_operator.py | 4 +- 18 files changed, 149 insertions(+), 89 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index a7d6e2033df2..e76057a1c437 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -719,6 +719,30 @@ MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, * \brief free cached operator */ MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle); +/*! + * \brief invoke cached operator + */ +MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs); +/*! + * \brief invoke a cached op + * \param handle the handle to the cached op + * \param num_inputs number of input NDArrays + * \param inputs input NDArrays + * \param num_outputs number of output NDArrays + * \param outputs output NDArrays + * \param out_stypes output ndarrays' stypes + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs, + const int** out_stypes); /*! * \brief invoke cached operator */ diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 2215acb3423e..3b4ba2147b8a 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -851,7 +851,6 @@ size_t num_aux_data(NDArrayStorageType stype); * \param from the ndarray we want to copy data from * \param to the target ndarray * \param priority Priority of the action. - * \param alloc_output whether to allocate memory for the output ndarray * \note The function name explicitly marks the order of from and to * due to different possible convention carried by copy function. 
*/ diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index 2a82c0899ea2..622d843bb92b 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -127,17 +127,24 @@ def __call__(self, *args, **kwargs): "CachedOp.__call__ got unexpected keyword argument(s): " + \ ', '.join(kwargs.keys())) - check_call(_LIB.MXInvokeCachedOp( + # return output stypes to avoid the c_api call for checking + # a handle's stype in _ndarray_cls + out_stypes = ctypes.POINTER(ctypes.c_int)() + + check_call(_LIB.MXInvokeCachedOpEx( self.handle, ctypes.c_int(len(args)), c_array(NDArrayHandle, [arr.handle for arr in args]), ctypes.byref(num_output), - ctypes.byref(output_vars))) + ctypes.byref(output_vars), + ctypes.byref(out_stypes))) if original_output is not None: return original_output if num_output.value == 1: - return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle)) + return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle), + stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]]) else: - return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle)) + return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), + stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]]) for i in range(num_output.value)] diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 8d96c751ccb3..65ac460e8ff5 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -255,11 +255,12 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): row_ids : NDArray or list of NDArray The row_ids for which to pull for each value. The row_ids doesn't have to be unique or sorted. + Examples -------- >>> shape = (3, 3) >>> kv.init('3', mx.nd.ones(shape)._to_rsp()) - >>> a = mx.nd.zeros(shape) + >>> a = mx.nd.zeros(shape, stype='row_sparse') >>> row_ids = mx.nd.array([0, 2], dtype='int64') >>> kv.row_sparse_pull('3', out=a, row_ids=row_ids) >>> print a.asnumpy() diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py index f8d77dcdeaa9..717888f0b528 100644 --- a/python/mxnet/ndarray/__init__.py +++ b/python/mxnet/ndarray/__init__.py @@ -9,4 +9,4 @@ from .ndarray import onehot_encode, power, subtract, true_divide, waitall, _new_empty_handle from .ndarray_utils import load, save, zeros, empty, array from .sparse_ndarray import _ndarray_cls, todense -from .sparse_ndarray import csr, row_sparse, BaseSparseNDArray, RowSparseNDArray, CSRNDArray +from .sparse_ndarray import csr_matrix, row_sparse_array, BaseSparseNDArray, RowSparseNDArray, CSRNDArray diff --git a/python/mxnet/ndarray/sparse_ndarray.py b/python/mxnet/ndarray/sparse_ndarray.py index 195f56283c25..a68ffdb839b4 100644 --- a/python/mxnet/ndarray/sparse_ndarray.py +++ b/python/mxnet/ndarray/sparse_ndarray.py @@ -115,7 +115,7 @@ def __setitem__(self, key, value): Examples -------- - >>> src = mx.nd.row_sparse([[1, 0, 2], [4, 5, 6]], [0, 2], (3,3)) + >>> src = mx.nd.row_sparse_array([[1, 0, 2], [4, 5, 6]], [0, 2], (3,3)) >>> src.asnumpy() array([[ 1., 0., 2.], [ 0., 0., 0.], @@ -325,7 +325,9 @@ def _aux_data(self, i): # pylint: disable=abstract-method class CSRNDArray(BaseSparseNDArray): - """A CSRNDArray represents a NDArray as three separate arrays: `data`, + """A sparse representation of tensor in standard CSR format. + + A CSRNDArray represents an NDArray as three separate arrays: `data`, `indptr` and `indices`. 
It uses the standard CSR representation where the column indices for row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored in values[indptr[i]:indptr[i+1]]. @@ -391,15 +393,23 @@ def indptr(self): # pylint: disable=abstract-method class RowSparseNDArray(BaseSparseNDArray): - """A RowSparseNDArray is typically used to represent a subset of a larger - NDArray with `default` of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. The values - in indices are the indices in the first dimension of the slices that have been extracted from - the larger NDArray. The indices are expected to be sorted in ascending order. + """A sparse representation of a set of tensor slices at given indices. - The corresponding NDArray ``dense`` with `default` storage represented by a ``rsp`` - RowSparseNDArray + A RowSparseNDArray represents a K-dimensional NDArray as two separate arrays: `data` and + `indices`. The `indices` stores the indices of the K-dimensional `data` slices extracted + from the dense NDArray in the first dimension. The corresponding NDArray ``dense`` + represented by RowSparseNDArray ``rsp`` has + + ``dense[rsp.indices[i], :, :, :, ...] = rsp.data[i, :, :, :, ...]``, - ``dense[rsp.indices[i], :, :, :, ...] = rsp.values[i, :, :, :, ...]`` + where `indices` is a 1-D integer NDArray with shape [D0], and `data` is an NDArray of any + dtype with shape [D0, D1, .., DK]. If the index of a slice in the first dimension + doesn't appear in `indices`, its values are zeros. + + A RowSparseNDArray is typically used to represent a subset of a larger dense NDArray of + shape [LARGE0, D1, .. , DK] where LARGE0 >> D0 and most row slices are zeros. + + The indices are expected to be sorted in ascending order. RowSparseNDArray is used principally in the definition of gradients for operations that have sparse gradients (e.g. dot with sparse inputs). @@ -407,10 +417,10 @@ class RowSparseNDArray(BaseSparseNDArray): Examples -------- >>> import mxnet as mx - >>> dense = mx.nd.array([[1,2],[0,0],[3,0],[0,0]]) + >>> dense = mx.nd.array([[1,2],[0,0],[3,0],[0,0],[0,0],[0,0]]) >>> rsp = dense._to_rsp() >>> rsp.indices.asnumpy() - array([0, 2], dtype=int32) + array([0, 2], dtype=int64) >>> rsp.data.asnumpy() array([[ 1., 2.], [ 3., 0.]], dtype=float32) @@ -452,6 +462,9 @@ def indices(self): def _prepare_src_array(src, dtype, default_dtype): + """Prepare `src` and its dtype so that they can be used to construct NDArray. + `src` is converted to a `np.ndarray` if it's neither an `NDArray` nor an `np.ndarray`. + """ if isinstance(src, NDArray): dtype = src.dtype if dtype is None else dtype else: @@ -464,8 +477,9 @@ def _prepare_src_array(src, dtype, default_dtype): return src, dtype -def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None): - """Creates a 2D array with compressed sparse row format. +def csr_matrix(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, + indices_type=None): + """Creates a 2D array with compressed sparse row (CSR) format. Parameters ---------- @@ -484,10 +498,10 @@ def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, in if `values` is an `NDArray`, `float32` otherwise. indptr_type: str or numpy.dtype, optional The data type of the indices array. The default dtype is ``indptr.dtype`` - if `indptr` is an `NDArray`, `int32` otherwise. + if `indptr` is an `NDArray`, `int64` otherwise. indices_type: str or numpy.dtype, optional The data type of the indices array.
The default dtype is ``indices.dtype`` - if `indicies` is an `NDArray`, `int32` otherwise. + if `indices` is an `NDArray`, `int64` otherwise. Returns ------- @@ -497,7 +511,7 @@ def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, in Example ------- >>> import mxnet as mx - >>> a = mx.nd.csr([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3)) + >>> a = mx.nd.csr_matrix([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3)) >>> a.asnumpy() array([[ 0., 1., 0.], [ 2., 0., 0.], @@ -540,13 +554,13 @@ def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, in return result -def row_sparse(data, indices, shape, ctx=None, dtype=None, indices_type=None): - """Creates a row sparse array with a set of tensor slices at given indices. +def row_sparse_array(data, indices, shape, ctx=None, dtype=None, indices_type=None): + """Creates a K-dimensional row sparse array with a set of tensor slices at given indices. Parameters ---------- data: array_like - An object exposing the array interface, with shape [D0, D1, .. Dn], where D0 is + An object exposing the array interface, with shape [D0, D1, .. DK], where D0 is the number of rows with non-zero entries. indices: array_like An object exposing the array interface, with shape [D0]. @@ -557,7 +571,7 @@ def row_sparse(data, indices, shape, ctx=None, dtype=None, indices_type=None): if `data` is an `NDArray`, `float32` otherwise. indices_type: str or numpy.dtype, optional The data type of the indices array. The default dtype is ``indices.dtype`` - if `indicies` is an `NDArray`, `int32` otherwise. + if `indices` is an `NDArray`, `int64` otherwise. Returns ------- Example ------- - >>> a = mx.nd.row_sparse([[1, 2], [3, 4]], [1, 4], (6, 2)) + >>> a = mx.nd.row_sparse_array([[1, 2], [3, 4]], [1, 4], (6, 2)) >>> a.asnumpy() array([[ 0., 0.], [ 1., 2.], diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index d7c0574f962d..109ab8cbbe00 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -316,8 +316,8 @@ class SGD(Optimizer): state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight weight = weight - state - For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update` and - :class:`~mxnet.ndarray.sgd_mom_update`. + Sparse updating is supported. For details of the update algorithm see + :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`. This optimizer accepts the following parameters in addition to those accepted by :class:`.Optimizer`. diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 7cc15a3366c3..4f6ee9101796 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -79,7 +79,7 @@ def random_sample(population, k): return population_copy[0:k] -def rand_sparse_ndarray(shape, stype, density=None): +def rand_sparse_ndarray(shape, stype, density=None, dtype=None): """Generate a random sparse ndarray.
Returns the ndarray, value(np) and indices(np) """ density = rnd.rand() if density is None else density if stype == 'row_sparse': @@ -87,26 +87,26 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None): idx_sample = rnd.rand(shape[0]) indices = np.argwhere(idx_sample < density).flatten() if indices.shape[0] == 0: - result = mx.nd.zeros(shape, stype='row_sparse') + result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype) return result, (np.array([], dtype='int64'), np.array([], dtype='int64')) # generate random values - val = rnd.rand(indices.shape[0], *shape[1:]) - arr = mx.nd.row_sparse(val, indices, shape, indices_type=np.int64) + val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype) + arr = mx.nd.row_sparse_array(val, indices, shape, indices_type=np.int64, dtype=dtype) return arr, (val, indices) elif stype == 'csr': assert(len(shape) == 2) csr = sp.rand(shape[0], shape[1], density=density, format='csr') - result = mx.nd.csr(csr.data, csr.indptr, csr.indices, shape) + result = mx.nd.csr_matrix(csr.data, csr.indptr, csr.indices, shape, dtype=dtype) return result, (csr.indptr, csr.indices, csr.data) else: assert(False), "unknown storage type" -def rand_ndarray(shape, stype, density=None): +def rand_ndarray(shape, stype, density=None, dtype=None): if stype == 'default': - arr = mx.nd.array(random_arrays(shape)) + arr = mx.nd.array(random_arrays(shape), dtype=dtype) else: - arr, _ = rand_sparse_ndarray(shape, stype, density=density) + arr, _ = rand_sparse_ndarray(shape, stype, density=density, dtype=dtype) return arr diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index f112862f5048..4920b40c925a 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -643,6 +643,24 @@ int MXInvokeCachedOp(CachedOpHandle handle, API_END(); } +int MXInvokeCachedOpEx(CachedOpHandle handle, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs, + const int **out_stypes) { // outputs storage types + API_BEGIN(); + MXInvokeCachedOp(handle, num_inputs, inputs, num_outputs, outputs); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + NDArray** output_nds = reinterpret_cast<NDArray**>(*outputs); + ret->out_types.resize(*num_outputs); + for (int i = 0; i < *num_outputs; ++i) { + ret->out_types[i] = output_nds[i]->storage_type(); + } + *out_stypes = dmlc::BeginPtr(ret->out_types); + API_END(); +} + int MXAutogradSetIsTraining(int is_training, int* prev) { API_BEGIN(); *prev = AutogradRuntime::Get()->SetIsTraining(static_cast<bool>(is_training)); diff --git a/src/common/utils.h b/src/common/utils.h index 86bc4a730d6b..7e1c24b3d519 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -43,7 +43,7 @@ inline bool GetDefaultBlobs(const std::vector<NDArray>& src, for (size_t i = 0; i < src.size(); i++) { auto& nd = src[i]; if (nd.storage_type() != kDefaultStorage) { - NDArray temp(nd.shape(), nd.ctx(), false); + NDArray temp(nd.shape(), nd.ctx(), false, nd.dtype()); temp_src->emplace_back(nd); temp_dst->emplace_back(temp); blobs->emplace_back(temp.data()); diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 0bf9b14f2f0e..63ea0001bea5 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1413,9 +1413,6 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; -#endif -#if EXECUTOR_DEBUG - LOG(INFO) << "Run node " << nid << " - " << seg_op.topo_end - 1;
#endif Engine::Get()->Push(seg_op.opr, seg_op.ctx, 0, profiling); nid = seg_op.topo_end - 1; @@ -1440,9 +1437,6 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; -#endif -#if EXECUTOR_DEBUG - LOG(INFO) << "Run node " << nid; #endif Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling); } else { diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index 490851feb177..b4634aa2f74b 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -138,21 +138,6 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs); } -inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, - const Context& ctx, - std::vector<int> *in_attrs, - std::vector<int> *out_attrs) { - // TODO(junwu): add ctx info into storage inference logic - CHECK_EQ(in_attrs->size(), static_cast<size_t>(2)) << " in operator " << attrs.name; - CHECK_EQ(out_attrs->size(), static_cast<size_t>(1)) << " in operator " << attrs.name; - auto &in = *in_attrs; - auto &out = *out_attrs; - CHECK_NE(in[1], kUndefinedStorage) << "rhs storage type must be known"; - if (in[0] == kUndefinedStorage) in[0] = in[1]; - if (out[0] == kUndefinedStorage) out[0] = in[1]; - return true; -} - // Transfer gradient and input to FGradient function struct ElemwiseGradUseIn { const char *op_name; diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 14673904798a..af5fe7ea7952 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -241,7 +241,7 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, } else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage) { NDArray out = outputs[0]; SGDUpdateRspDnsImpl<xpu>(param, ctx, inputs[0], inputs[1].data(), req[0], &out); - } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) { + } else { FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDUpdate<xpu>, "SGDUpdate"); } } @@ -600,8 +600,7 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, mom_stype == kRowSparseStorage) { NDArray out = outputs[0]; SGDMomUpdateRspDnsImpl<xpu>(param, ctx, weight, grad.data(), mom, req[0], &out); - } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage && - mom_stype == kDefaultStorage) { + } else { FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDMomUpdate<xpu>, "SGDMomUpdate"); } } diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 63a4a3ddb795..976b8cab4755 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -142,6 +142,21 @@ void IdentityComputeEx(const nnvm::NodeAttrs& attrs, } } +inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + // TODO(junwu): add ctx info into storage inference logic + CHECK_EQ(in_attrs->size(), static_cast<size_t>(2)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast<size_t>(1)) << " in operator " << attrs.name; + auto &in = *in_attrs; + auto &out = *out_attrs; + CHECK_NE(in[1], kUndefinedStorage) << "rhs storage type must be known"; + if (in[0] == kUndefinedStorage) STORAGE_TYPE_ASSIGN_CHECK(in, 0, in[1]); + if (out[0] == kUndefinedStorage) STORAGE_TYPE_ASSIGN_CHECK(out, 0, in[1]); + return true; +} + template<typename xpu> void IdentityLikeRhsComputeEx(const
nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 92ce08d3acda..98d6d69a0aac 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -487,8 +487,8 @@ def fm(factor_size, feature_dim, init): import scipy.sparse as sp # generate some random scipy csr data csr_sp = sp.rand(num_samples, feature_dim, density=0.1, format='csr') - csr_nd = mx.nd.csr(csr_sp.data, csr_sp.indptr, csr_sp.indices, - (num_samples, feature_dim)) + csr_nd = mx.nd.csr_matrix(csr_sp.data, csr_sp.indptr, csr_sp.indices, + (num_samples, feature_dim)) label = mx.nd.ones((num_samples,1)) # the alternative is to use LibSVMIter train_iter = mx.io.NDArrayIter(data=csr_nd, diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 8399194efbba..c7e48090d404 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -35,7 +35,7 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) w1 = w2.copyto(default_context()) elif w_stype == 'row_sparse': - w2 = rand_ndarray(shape, w_stype, density=1) + w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype) w1 = w2.copyto(default_context()).todense() else: raise Exception("type not supported yet") @@ -43,7 +43,7 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) g1 = g2.copyto(default_context()) elif g_stype == 'row_sparse': - g2 = rand_ndarray(shape, g_stype) + g2 = rand_ndarray(shape, g_stype, dtype=dtype) g1 = g2.copyto(default_context()).todense() else: raise Exception("type not supported yet") @@ -260,24 +260,28 @@ def test_sparse_sgd(): mx.random.seed(0) opt1 = PySparseSGD opt2 = mx.optimizer.SGD - shape = (3, 4) - kwargs = [{}, - {'momentum': 0.9}, - {'clip_gradient': 0.5}, - {'clip_gradient': 0.4, 'rescale_grad': 0.14}, - {'rescale_grad': 0.8}, - {'clip_gradient': 0.5, 'wd': 0.07}, - {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03}, - {'rescale_grad': 0.8, 'wd': 0.05}, - {'clip_gradient': 0.5, 'momentum': 0.9}, - {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'momentum': 0.9}, - {'rescale_grad': 0.8, 'momentum': 0.9}, - {'clip_gradient': 0.5, 'wd': 0.07, 'momentum': 0.9}, - {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03, 'momentum': 0.9}, - {'rescale_grad': 0.8, 'wd': 0.05, 'momentum': 0.9}] - for kwarg in kwargs: - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, 'float32', w_stype='row_sparse', g_stype='row_sparse') - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, 'float32', w_stype='row_sparse', g_stype='default') + shape = (3, 4, 5) + mom_options = [{}, {'momentum': 0.9}] + cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}] + rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] + wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] + mp_options = [{}] + for dtype in [np.float32]: + for mom_option in mom_options: + for cg_option in cg_options: + for rg_option in rg_options: + for wd_option in wd_options: + for mp_option in mp_options: + kwarg = {} + kwarg.update(mom_option) + kwarg.update(cg_option) + kwarg.update(rg_option) + kwarg.update(wd_option) + kwarg.update(mp_option) + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, + w_stype='row_sparse', g_stype='row_sparse') + 
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, + w_stype='row_sparse', g_stype='default') # ADAM diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 06ab437df226..4c98cea685cb 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -421,7 +421,7 @@ def test_create_csr(): data = matrix.data indptr = matrix.indptr indices = matrix.indices - csr_created = mx.nd.csr(data=data, indptr=indptr, indices=indices, shape=shape) + csr_created = mx.nd.csr_matrix(data=data, indptr=indptr, indices=indices, shape=shape) assert csr_created.stype == 'csr' assert same(csr_created.data.asnumpy(), data.asnumpy()) assert same(csr_created.indptr.asnumpy(), indptr.asnumpy()) @@ -439,7 +439,7 @@ def test_create_row_sparse(): matrix = rand_ndarray(shape, 'row_sparse', density) data = matrix.data indices = matrix.indices - rsp_created = mx.nd.row_sparse(data=data, indices=indices, shape=shape) + rsp_created = mx.nd.row_sparse_array(data=data, indices=indices, shape=shape) assert rsp_created.stype == 'row_sparse' assert same(rsp_created.data.asnumpy(), data.asnumpy()) assert same(rsp_created.indices.asnumpy(), indices.asnumpy()) diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index dea5b99f05b0..9cd00f8de08e 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -45,8 +45,8 @@ def test_elemwise_add_ex_multiple_stages(): val2 = mx.nd.array([[5, 10]]); idx1 = mx.nd.array([0], dtype=np.int64); idx2 = mx.nd.array([1], dtype=np.int64); - sp_nd1 = mx.nd.row_sparse(val1, idx1, shape) - sp_nd2 = mx.nd.row_sparse(val2, idx2, shape) + sp_nd1 = mx.nd.row_sparse_array(val1, idx1, shape) + sp_nd2 = mx.nd.row_sparse_array(val2, idx2, shape) ds_nd = mx.nd.array(ds_np) # sparse + sparse = sparse
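For reference, a minimal usage sketch of the constructors renamed in this patch, `mx.nd.csr_matrix` and `mx.nd.row_sparse_array`. It assumes an MXNet build that includes these changes; the input values simply mirror the docstring examples above, and the dense outputs follow from those inputs.

>>> import mxnet as mx
>>> a = mx.nd.csr_matrix([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3))  # data, indptr, indices, shape
>>> a.stype
'csr'
>>> a.asnumpy()
array([[ 0.,  1.,  0.],
       [ 2.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  3.]], dtype=float32)
>>> b = mx.nd.row_sparse_array([[1, 2], [3, 4]], [1, 4], (6, 2))  # data, indices, shape
>>> b.stype
'row_sparse'
>>> b.asnumpy()
array([[ 0.,  0.],
       [ 1.,  2.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 3.,  4.],
       [ 0.,  0.]], dtype=float32)

Both constructors default their index dtypes to int64, consistent with the int32 to int64 changes in the docstrings above.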