diff --git a/benchmark/python/cast_storage.py b/benchmark/python/cast_storage.py index 38398e5e164a..dbe0cdca1716 100644 --- a/benchmark/python/cast_storage.py +++ b/benchmark/python/cast_storage.py @@ -26,7 +26,7 @@ def run_cast_storage_synthetic(): def dense_to_sparse(m, n, density, ctx, repeat, stype): set_default_context(ctx) data_shape = (m, n) - dns_data = rand_ndarray(data_shape, stype, density).todense() + dns_data = rand_ndarray(data_shape, stype, density).tostype('default') dns_data.wait_to_read() # do one warm up run, verify correctness diff --git a/benchmark/python/sparse_op.py b/benchmark/python/sparse_op.py index 15ca4df1be73..f01d70dbfd13 100644 --- a/benchmark/python/sparse_op.py +++ b/benchmark/python/sparse_op.py @@ -92,7 +92,7 @@ def get_iter(path, data_shape, batch_size): for batch in train_iter: data = train_iter.getdata() csr_data.append(data) - dns_data.append(data.todense()) + dns_data.append(data.tostype('default')) num_batch += 1 bag_of_data = [csr_data, dns_data] num_repeat = 5 @@ -140,10 +140,10 @@ def bench_dot_forward(m, k, n, density, ctx, repeat): dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx) data_shape = (m, k) csr_data = rand_ndarray(data_shape, 'csr', density) - dns_data = csr_data.todense() + dns_data = csr_data.tostype('default') rhs_dns_np = dns.asnumpy() lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) # csr in scipy lhs_dns_np = lhs_csr_sp.todense() data = [dns_data, csr_data] costs = [] @@ -169,10 +169,10 @@ def bench_dot_backward(m, k, n, density, ctx, repeat): dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx) data_shape = (m, k) csr_data = rand_ndarray(data_shape, 'csr', density) - dns_data = csr_data.todense() + dns_data = csr_data.tostype('default') rhs_dns_np = dns.asnumpy() lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) lhs_dns_np = lhs_csr_sp.todense() data = [dns_data, csr_data] costs = [] diff --git a/docs/api/python/ndarray.md b/docs/api/python/ndarray.md index e7361cd2683c..c38e6e301798 100644 --- a/docs/api/python/ndarray.md +++ b/docs/api/python/ndarray.md @@ -107,6 +107,7 @@ We summarize the interface for each class in the following sections. NDArray.asnumpy NDArray.asscalar NDArray.astype + NDArray.tostype ``` ### Array change shape @@ -191,6 +192,7 @@ We summarize the interface for each class in the following sections. :nosignatures: RowSparseNDArray.copyto + RowSparseNDArray.tostype RowSparseNDArray.__setitem__ RowSparseNDArray.__getitem__ RowSparseNDArray.data @@ -204,6 +206,7 @@ We summarize the interface for each class in the following sections. :nosignatures: CSRNDArray.copyto + CSRNDArray.tostype CSRNDArray.__setitem__ CSRNDArray.__getitem__ CSRNDArray.data diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index a7d6e2033df2..e76057a1c437 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -719,6 +719,30 @@ MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, * \brief free cached operator */ MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle); +/*! + * \brief invoke cached operator + */ +MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs); +/*! 
+ * \brief invoke a cached op + * \param handle the handle to the cached op + * \param num_inputs number of input NDArrays + * \param inputs input NDArrays + * \param num_outputs number of output NDArrays + * \param outputs output NDArrays + * \param out_stypes output ndarrays' stypes + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs, + const int** out_stypes); /*! * \brief invoke cached operator */ diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 2215acb3423e..61e0d612f31b 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -193,13 +193,16 @@ class NDArray { } /*! - * \return the shape of aux data at ith index. If it doesn't exist, return an empty one. + * \brief get the shape of aux_data(index) + * \param index the index of the aux data + * \return the shape of aux data at given index */ - inline const TShape aux_shape(size_t i) const { + inline const TShape& aux_shape(size_t index) const { CHECK(storage_type() != kDefaultStorage); - return ptr_->aux_shapes[i]; + return ptr_->aux_shapes[index]; } + /*! \return the shapes of all aux data */ const std::vector<TShape>& aux_shapes() const { CHECK(storage_type() != kDefaultStorage); return ptr_->aux_shapes; @@ -212,8 +215,8 @@ * for the final result. After the operation is done, the exact size of * the shape is known and need to be reset using this function. */ - inline void set_aux_shape(size_t i, const TShape& shape) const { - ptr_->set_aux_shape(i, shape); + inline void set_aux_shape(size_t index, const TShape& shape) const { - ptr_->set_aux_shape(i, shape); + ptr_->set_aux_shape(index, shape); } /*! @@ -851,7 +854,6 @@ size_t num_aux_data(NDArrayStorageType stype); * \param from the ndarray we want to copy data from * \param to the target ndarray * \param priority Priority of the action. - * \param alloc_output whether to allocate memory for the output ndarray * \note The function name explicitly marks the order of from and to due to different possible convention carried by copy function. 
*/ diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index 2a82c0899ea2..622d843bb92b 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -127,17 +127,24 @@ def __call__(self, *args, **kwargs): "CachedOp.__call__ got unexpected keyword argument(s): " + \ ', '.join(kwargs.keys())) - check_call(_LIB.MXInvokeCachedOp( + # return output stypes to avoid the c_api call for checking + # a handle's stype in _ndarray_cls + out_stypes = ctypes.POINTER(ctypes.c_int)() + + check_call(_LIB.MXInvokeCachedOpEx( self.handle, ctypes.c_int(len(args)), c_array(NDArrayHandle, [arr.handle for arr in args]), ctypes.byref(num_output), - ctypes.byref(output_vars))) + ctypes.byref(output_vars), + ctypes.byref(out_stypes))) if original_output is not None: return original_output if num_output.value == 1: - return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle)) + return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle), + stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]]) else: - return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle)) + return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), + stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]]) for i in range(num_output.value)] diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 8d96c751ccb3..18f612e1d0e9 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -255,11 +255,12 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): row_ids : NDArray or list of NDArray The row_ids for which to pull for each value. The row_ids doesn't have to be unique or sorted. + Examples -------- >>> shape = (3, 3) - >>> kv.init('3', mx.nd.ones(shape)._to_rsp()) - >>> a = mx.nd.zeros(shape) + >>> kv.init('3', mx.nd.ones(shape).tostype('row_sparse')) + >>> a = mx.nd.zeros(shape, stype='row_sparse') >>> row_ids = mx.nd.array([0, 2], dtype='int64') >>> kv.row_sparse_pull('3', out=a, row_ids=row_ids) >>> print a.asnumpy() diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py index 7a55246c9b7a..1220a6600d17 100644 --- a/python/mxnet/ndarray/__init__.py +++ b/python/mxnet/ndarray/__init__.py @@ -5,6 +5,6 @@ from .op import CachedOp # pylint: disable=wildcard-import from .ndarray import * -from .ndarray_utils import load, save, zeros, empty, array -from .sparse_ndarray import _ndarray_cls, todense -from .sparse_ndarray import csr, row_sparse, BaseSparseNDArray, RowSparseNDArray, CSRNDArray +from .utils import load, save, zeros, empty, array +from .sparse_ndarray import _ndarray_cls, csr_matrix, row_sparse_array +from .sparse_ndarray import BaseSparseNDArray, RowSparseNDArray, CSRNDArray diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index d756333b9847..c13132d90524 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -1089,13 +1089,15 @@ def backward(self, out_grad=None, retain_graph=False, is_train=True): ctypes.c_int(retain_graph), ctypes.c_int(is_train))) - def _to_csr(self): - # pylint: disable=undefined-variable - return cast_storage(self, stype='csr') + def tostype(self, stype): + """Return a copy of the array with chosen storage type. 
- def _to_rsp(self): - # pylint: disable=undefined-variable - return cast_storage(self, stype='row_sparse') + Returns + ------- + NDArray, CSRNDArray or RowSparseNDArray + A copy of the array with the chosen storage stype + """ + return cast_storage(self, stype=stype) def onehot_encode(indices, out): """One-hot encoding indices into matrix out. diff --git a/python/mxnet/ndarray/sparse_ndarray.py b/python/mxnet/ndarray/sparse_ndarray.py index 88393be66c25..37fd4ffd0113 100644 --- a/python/mxnet/ndarray/sparse_ndarray.py +++ b/python/mxnet/ndarray/sparse_ndarray.py @@ -144,7 +144,7 @@ def _aux_types(self): def asnumpy(self): """Return a dense ``numpy.ndarray`` object with value copied from this array """ - return self.todense().asnumpy() + return self.tostype('default').asnumpy() def astype(self, dtype): """Returns a copy of the array after casting to a specified type. @@ -189,9 +189,6 @@ def copyto(self, other): else: raise TypeError('copyto does not support type ' + str(type(other))) - def todense(self): - return todense(self) - def _data(self): """A deep copy NDArray of the data array associated with the BaseSparseNDArray. @@ -217,7 +214,9 @@ def _aux_data(self, i): # pylint: disable=abstract-method class CSRNDArray(BaseSparseNDArray): - """A CSRNDArray represents a NDArray as three separate arrays: `data`, + """A sparse representation of 2D NDArray in the standard CSR format. + + A CSRNDArray represents an NDArray as three separate arrays: `data`, `indptr` and `indices`. It uses the standard CSR representation where the column indices for row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored in values[indptr[i]:indptr[i+1]]. @@ -225,7 +224,7 @@ class CSRNDArray(BaseSparseNDArray): Example ------- >>> a = mx.nd.array([[0, 1, 0], [2, 0, 0], [0, 0, 0], [0, 0, 3]]) - >>> a = a._to_csr() + >>> a = a.tostype('csr') >>> a.indices.asnumpy() array([1, 0, 2]) >>> a.indptr.asnumpy() @@ -269,13 +268,16 @@ def __getitem__(self, key): Examples -------- - >>> x = mx.nd.zeros((2, 3), stype='csr') - >>> x[:] = mx.nd.arange(0,6).reshape((2,3)) - >>> x.asnumpy() - array([[ 0., 1., 2.], - [ 3., 4., 5.]], dtype=float32) - >>> x[1:2].asnumpy() - array([[ 3., 4., 5.]], dtype=float32) + >>> indptr = np.array([0, 2, 3, 6]) + >>> indices = np.array([0, 2, 2, 0, 1, 2]) + >>> data = np.array([1, 2, 3, 4, 5, 6]) + >>> a = mx.nd.csr_matrix(data, indptr, indices, (3, 3)) + >>> a.asnumpy() + array([[1, 0, 2], + [0, 0, 3], + [4, 5, 6]]) + >>> a[1:2].asnumpy() + array([[0, 0, 3]], dtype=float32) """ if isinstance(key, int): raise ValueError("__getitem__ with int key is not implemented for CSRNDArray") @@ -311,7 +313,7 @@ def __setitem__(self, key, value): [ 0., 0., 0.], [ 0., 0., 0.]], dtype=float32) >>> # assign CSRNDArray with same storage type - >>> x = mx.nd.ones('row_sparse', (3,3))._to_csr() + >>> x = mx.nd.ones('row_sparse', (3,3)).tostype('csr') >>> x[:] = src >>> x.asnumpy() array([[ 1., 1., 1.], @@ -386,6 +388,18 @@ def data(self): """ return self._data() + def tostype(self, stype): + """Return a copy of the array with chosen storage type. + + Returns + ------- + NDArray or CSRNDArray + A copy of the array with the chosen storage stype + """ + if stype == 'row_sparse': + raise ValueError("cast_storage from csr to row_sparse is not supported") + return cast_storage(self, stype=stype) + def copyto(self, other): """Copies the value of this array to another array. 
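A quick usage sketch of the `tostype` methods introduced above (a hypothetical session; values are illustrative, and the unsupported csr to row_sparse cast mirrors the ValueError raised in `CSRNDArray.tostype`):

>>> import mxnet as mx
>>> dns = mx.nd.array([[0, 1], [2, 0]])
>>> csr = dns.tostype('csr')            # NDArray -> CSRNDArray copy
>>> csr.indices.asnumpy()
array([1, 0])
>>> csr.tostype('default').asnumpy()    # and back to a dense NDArray
array([[ 0.,  1.],
       [ 2.,  0.]], dtype=float32)
>>> # csr.tostype('row_sparse') would raise the ValueError shown above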
@@ -420,29 +434,40 @@ def copyto(self, other): # pylint: disable=abstract-method class RowSparseNDArray(BaseSparseNDArray): - """A RowSparseNDArray is typically used to represent a subset of a larger - NDArray with `default` of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. The values - in indices are the indices in the first dimension of the slices that have been extracted from - the larger NDArray. The indices are expected to be sorted in ascending order. + """A sparse representation of a set of NDArray row slices at given indices. - The corresponding NDArray ``dense`` with `default` storage represented by a ``rsp`` - RowSparseNDArray + A RowSparseNDArray represents a multidimensional NDArray using two separate arrays: `data` and + `indices`. - ``dense[rsp.indices[i], :, :, :, ...] = rsp.values[i, :, :, :, ...]`` + - data: an NDArray of any dtype with shape [D0, D1, ..., Dn]. + - indices: a 1-D int64 NDArray with shape [D0]. - RowSparseNDArray is used principally in the definition of gradients for operations - that have sparse gradients (e.g. dot with sparse inputs). + The `indices` stores the indices of the row slices with non-zeros, + while the values are stored in `data`. The corresponding NDArray ``dense`` + represented by RowSparseNDArray ``rsp`` has - Examples - -------- - >>> import mxnet as mx - >>> dense = mx.nd.array([[1,2],[0,0],[3,0],[0,0]]) - >>> rsp = dense._to_rsp() - >>> rsp.indices.asnumpy() - array([0, 2], dtype=int32) - >>> rsp.data.asnumpy() - array([[ 1., 2.], - [ 3., 0.]], dtype=float32) + ``dense[rsp.indices[i], :, :, :, ...] = rsp.data[i, :, :, :, ...]`` + + >>> dense.asnumpy() + array([[ 1., 2., 3.], + [ 0., 0., 0.], + [ 4., 0., 5.], + [ 0., 0., 0.], + [ 0., 0., 0.]], dtype=float32) + >>> rsp = dense.tostype('row_sparse') + >>> rsp.indices.asnumpy() + array([0, 2], dtype=int64) + >>> rsp.data.asnumpy() + array([[ 1., 2., 3.], + [ 4., 0., 5.]], dtype=float32) + + A RowSparseNDArray is typically used to represent non-zero row-slices of a large NDArray + of shape [LARGE0, D1, .. , Dn] where LARGE0 >> D0 and most row slices are zeros. + + The indices are expected to be sorted in ascending order. + + RowSparseNDArray is used principally in the definition of gradients for operations + that have sparse gradients (e.g. sparse dot and sparse embedding). """ def __reduce__(self): return RowSparseNDArray, (None,), super(RowSparseNDArray, self).__getstate__() @@ -549,7 +574,7 @@ def __setitem__(self, key, value): raise TypeError('type %s not supported' % str(type(value))) else: assert(isinstance(key, (int, tuple))) - raise Exception('RowSparseNDArray only supports [:] for assignment') + raise TypeError('RowSparseNDArray only supports [:] for assignment') @property def indices(self): @@ -575,6 +600,18 @@ def data(self): """ return self._data() + def tostype(self, stype): + """Return a copy of the array with chosen storage type. + + Returns + ------- + NDArray or RowSparseNDArray + A copy of the array with the chosen storage stype + """ + if stype == 'csr': + raise ValueError("cast_storage from row_sparse to csr is not supported") + return cast_storage(self, stype=stype) + def copyto(self, other): """Copies the value of this array to another array. @@ -608,6 +645,9 @@ def copyto(self, other): raise TypeError('copyto does not support type ' + str(type(other))) def _prepare_src_array(src, dtype, default_dtype): + """Prepare `src` and its dtype so that they can be used to construct NDArray. + `src` is converted to a `np.ndarray` if it's neither an `NDArray` nor an `np.ndarray`. 
+ """ if isinstance(src, NDArray): dtype = src.dtype if dtype is None else dtype else: @@ -620,8 +660,9 @@ def _prepare_src_array(src, dtype, default_dtype): return src, dtype -def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None): - """Creates a 2D array with compressed sparse row format. +def csr_matrix(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, + indices_type=None): + """Creates a 2D array with compressed sparse row(CSR) format. Parameters ---------- @@ -640,10 +681,10 @@ def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, in if `values` is an `NDArray`, `float32` otherwise. indptr_type: str or numpy.dtype, optional The data type of the indices array. The default dtype is ``indptr.dtype`` - if `indptr` is an `NDArray`, `int32` otherwise. + if `indptr` is an `NDArray`, `int64` otherwise. indices_type: str or numpy.dtype, optional The data type of the indices array. The default dtype is ``indices.dtype`` - if `indicies` is an `NDArray`, `int32` otherwise. + if `indicies` is an `NDArray`, `int64` otherwise. Returns ------- @@ -653,7 +694,7 @@ def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, in Example ------- >>> import mxnet as mx - >>> a = mx.nd.csr([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3)) + >>> a = mx.nd.csr_matrix([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3)) >>> a.asnumpy() array([[ 0., 1., 0.], [ 2., 0., 0.], @@ -696,13 +737,13 @@ def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, in return result -def row_sparse(data, indices, shape, ctx=None, dtype=None, indices_type=None): - """Creates a row sparse array with a set of tensor slices at given indices. +def row_sparse_array(data, indices, shape, ctx=None, dtype=None, indices_type=None): + """Creates a multidimensional row sparse array with a set of tensor slices at given indices. Parameters ---------- data: array_like - An object exposing the array interface, with shape [D0, D1, .. Dn], where D0 is + An object exposing the array interface, with shape [D0, D1, .. DK], where D0 is the number of rows with non-zeros entries. indices: array_like An object exposing the array interface, with shape [D0]. @@ -713,7 +754,7 @@ def row_sparse(data, indices, shape, ctx=None, dtype=None, indices_type=None): if `data` is an `NDArray`, `float32` otherwise. indices_type: str or numpy.dtype, optional The data type of the indices array. The default dtype is ``indices.dtype`` - if `indicies` is an `NDArray`, `int32` otherwise. + if `indicies` is an `NDArray`, `int64` otherwise. Returns ------- @@ -722,7 +763,7 @@ def row_sparse(data, indices, shape, ctx=None, dtype=None, indices_type=None): Example ------- - >>> a = mx.nd.row_sparse([[1, 2], [3, 4]], [1, 4], (6, 2)) + >>> a = mx.nd.row_sparse_array([[1, 2], [3, 4]], [1, 4], (6, 2)) >>> a.asnumpy() array([[ 0., 0.], [ 1., 2.], @@ -758,18 +799,6 @@ def row_sparse(data, indices, shape, ctx=None, dtype=None, indices_type=None): check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, indices.handle, ctypes.c_int(0))) return result - -def todense(source): - """ Return a dense array representation of a BaseSparseNDArray. 
- - Returns - ------- - NDArray - A copy of the array with `default` storage stype - """ - return cast_storage(source, stype='default') - - def _ndarray_cls(handle, writable=True, stype=None): if stype is None: stype = _storage_type(handle) diff --git a/python/mxnet/ndarray/ndarray_utils.py b/python/mxnet/ndarray/utils.py similarity index 100% rename from python/mxnet/ndarray/ndarray_utils.py rename to python/mxnet/ndarray/utils.py diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index bfcde757e77e..6ecbaed6be86 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -322,8 +322,8 @@ class SGD(Optimizer): state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight weight = weight - state - For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update` and - :class:`~mxnet.ndarray.sgd_mom_update`. + Sparse updating is supported. For details of the update algorithm see + :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`. This optimizer accepts the following parameters in addition to those accepted by :class:`.Optimizer`. @@ -348,8 +348,6 @@ def create_state(self, index, weight): momentum = None weight_master_copy = None if self.multi_precision and weight.dtype == numpy.float16: - assert(weight.stype == 'default'), \ - "multi-precision doesn't supprot non-default weight yet" weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32) if self.momentum != 0.0: momentum = zeros(weight.shape, weight.context, dtype=numpy.float32, diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 7cc15a3366c3..8b6985ec5582 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -79,34 +79,35 @@ def random_sample(population, k): return population_copy[0:k] -def rand_sparse_ndarray(shape, stype, density=None): +def rand_sparse_ndarray(shape, stype, density=None, dtype=None): """Generate a random sparse ndarray. 
Returns the ndarray, value(np) and indices(np) """ density = rnd.rand() if density is None else density + dtype = default_dtype() if dtype is None else dtype if stype == 'row_sparse': # sample index idx_sample = rnd.rand(shape[0]) indices = np.argwhere(idx_sample < density).flatten() if indices.shape[0] == 0: - result = mx.nd.zeros(shape, stype='row_sparse') - return result, (np.array([], dtype='int64'), np.array([], dtype='int64')) + result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype) + return result, (np.array([], dtype=dtype), np.array([], dtype='int64')) # generate random values - val = rnd.rand(indices.shape[0], *shape[1:]) - arr = mx.nd.row_sparse(val, indices, shape, indices_type=np.int64) + val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype) + arr = mx.nd.row_sparse_array(val, indices, shape, indices_type=np.int64, dtype=dtype) return arr, (val, indices) elif stype == 'csr': assert(len(shape) == 2) - csr = sp.rand(shape[0], shape[1], density=density, format='csr') - result = mx.nd.csr(csr.data, csr.indptr, csr.indices, shape) + csr = sp.rand(shape[0], shape[1], density=density, format='csr', dtype=dtype) + result = mx.nd.csr_matrix(csr.data, csr.indptr, csr.indices, shape, dtype=dtype) return result, (csr.indptr, csr.indices, csr.data) else: assert(False), "unknown storage type" -def rand_ndarray(shape, stype, density=None): +def rand_ndarray(shape, stype, density=None, dtype=None): if stype == 'default': - arr = mx.nd.array(random_arrays(shape)) + arr = mx.nd.array(random_arrays(shape), dtype=dtype) else: - arr, _ = rand_sparse_ndarray(shape, stype, density=density) + arr, _ = rand_sparse_ndarray(shape, stype, density=density, dtype=dtype) return arr @@ -731,7 +732,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= for k, v in args_grad_npy.items(): nd = mx.nd.array(v, ctx=ctx) if grad_stypes is not None and k in grad_stypes: - out = mx.nd.cast_storage(nd, stype=grad_stypes[k]) + out = nd.tostype(grad_stypes[k]) args_grad_data[k] = out else: args_grad_data[k] = nd diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 3515558f8900..f09303cb031f 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -668,6 +668,25 @@ int MXInvokeCachedOp(CachedOpHandle handle, API_END(); } +int MXInvokeCachedOpEx(CachedOpHandle handle, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs, + const int **out_stypes) { // outputs storage types + API_BEGIN(); + int err = MXInvokeCachedOp(handle, num_inputs, inputs, num_outputs, outputs); + if (err != 0) return err; + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + NDArray** output_nds = reinterpret_cast<NDArray**>(*outputs); + ret->out_types.resize(*num_outputs); + for (int i = 0; i < *num_outputs; ++i) { + ret->out_types[i] = output_nds[i]->storage_type(); + } + *out_stypes = dmlc::BeginPtr(ret->out_types); + API_END(); +} + int MXAutogradSetIsTraining(int is_training, int* prev) { API_BEGIN(); *prev = AutogradRuntime::Get()->SetIsTraining(static_cast<bool>(is_training)); diff --git a/src/common/utils.h b/src/common/utils.h index b41510ccb091..c2405f908492 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -55,7 +55,7 @@ inline bool SetupDefaultBlobs(const std::vector<NDArray>& src, if (idx_map != nullptr) { (*idx_map)[i] = temp_dst->size(); } - NDArray temp(nd.shape(), nd.ctx(), false); + NDArray temp(nd.shape(), nd.ctx(), false, nd.dtype()); temp_src->emplace_back(nd); temp_dst->emplace_back(temp); blobs->emplace_back(temp.data()); diff --git 
a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 0bf9b14f2f0e..e9e4c3001c41 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -16,6 +16,11 @@ namespace mxnet { namespace exec { + +GraphExecutor::GraphExecutor() { + log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false); +} + GraphExecutor::~GraphExecutor() { for (auto& n : op_nodes_) { if (n.cached_opr != nullptr) { @@ -531,10 +536,10 @@ void GraphExecutor::Init(nnvm::Symbol symbol, } ++arg_top; } -#if EXECUTOR_DEBUG - LOG(INFO) << "\tassign data entry\t" << eid << " as stype " - << data_entry_[eid].storage_type() << " (input)"; -#endif + if (log_verbose_) { + LOG(INFO) << "\tassign data entry\t" << eid << " as stype " + << data_entry_[eid].storage_type() << " (input)"; + } } // expand arg_shapes and arg_dtypes to contain backward inputs @@ -599,16 +604,16 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, data_entry_[eid] = aux_state_vec->back(); aux_state_map_.emplace(arg_name, aux_state_vec->back()); ++aux_top; -#if EXECUTOR_DEBUG - LOG(INFO) << "\tassign aux entry\t" << eid << "\t as stype " << inferred_stype; -#endif + if (log_verbose_) { + LOG(INFO) << "\tassign aux entry\t" << eid << "\t as stype " << inferred_stype; + } } else { // in_args EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], inferred_dtype, in_arg_vec); data_entry_[eid] = in_arg_vec->back(); -#if EXECUTOR_DEBUG - LOG(INFO) << "\tassign data entry\t" << eid << "\tas stype " << inferred_stype; -#endif + if (log_verbose_) { + LOG(INFO) << "\tassign data entry\t" << eid << "\tas stype " << inferred_stype; + } // Get the storage type for grad if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); @@ -619,9 +624,9 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], inferred_dtype, arg_grad_vec); -#if EXECUTOR_DEBUG - LOG(INFO) << "\tassign grad entry\t" << grad_eid << "\tas stype " << grad_stype; -#endif + if (log_verbose_) { + LOG(INFO) << "\tassign grad entry\t" << grad_eid << "\tas stype " << grad_stype; + } grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); } @@ -1066,9 +1071,9 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { } else { data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); } -#if EXECUTOR_DEBUG - LOG(INFO) << "\tinit head_g entry\t" << data_eid << "\tas stype " << stype; -#endif + if (log_verbose_) { + LOG(INFO) << "\tinit head_g entry\t" << data_eid << "\tas stype " << stype; + } } // get maximum bytes in each pool for (size_t i = 0; i < vshape.size(); ++i) { @@ -1151,9 +1156,9 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { } else { data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i]); } -#if EXECUTOR_DEBUG - LOG(INFO) << "\tinit data entry\t" << i << "\tas stype " << storage_type; -#endif + if (log_verbose_) { + LOG(INFO) << "\tinit data entry\t" << i << "\tas stype " << storage_type; + } } } @@ -1174,22 +1179,22 @@ void GraphExecutor::InitCachedOps() { // setup the array and requirements. 
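The EXECUTOR_DEBUG compile-time guards removed in this file become a runtime switch read once in the new GraphExecutor constructor. A minimal sketch of enabling it from Python, assuming the variable is set before any executor is constructed:

>>> import os
>>> os.environ['MXNET_EXEC_VERBOSE_LOGGING'] = '1'  # read via dmlc::GetEnv
>>> import mxnet as mx
>>> c = mx.symbol.Variable('a') + mx.symbol.Variable('b')
>>> ex = c.simple_bind(mx.cpu(), a=(2, 2), b=(2, 2))  # logs stype assignments while binding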
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; -#if EXECUTOR_DEBUG - if (inode.source->is_variable()) { - LOG(INFO) << "node " << nid << " var"; - } else { - LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name; - auto exec = op_execs[nid]; - for (const auto& e : inode.inputs) { - auto eid = idx.entry_id(e); - LOG(INFO) << "\t\tinput " << eid << " stype: " << vstorage_type[eid]; - } - for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { - uint32_t eid = idx.entry_id(nid, index); - LOG(INFO) << "\t\toutput " << eid << " stype: " << vstorage_type[eid]; + if (log_verbose_) { + if (inode.source->is_variable()) { + LOG(INFO) << "node " << nid << " var"; + } else { + LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name; + auto exec = op_execs[nid]; + for (const auto& e : inode.inputs) { + auto eid = idx.entry_id(e); + LOG(INFO) << "\t\tinput " << eid << " stype: " << vstorage_type[eid]; + } + for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { + uint32_t eid = idx.entry_id(nid, index); + LOG(INFO) << "\t\toutput " << eid << " stype: " << vstorage_type[eid]; + } } } -#endif if (inode.source->is_variable()) continue; #if MXNET_USE_PROFILER op_nodes_[nid].opr_name = inode.source->op()->name.c_str(); @@ -1413,9 +1418,6 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; -#endif -#if EXECUTOR_DEBUG - LOG(INFO) << "Run node " << nid << " - " << seg_op.topo_end - 1; #endif Engine::Get()->Push(seg_op.opr, seg_op.ctx, 0, profiling); nid = seg_op.topo_end - 1; @@ -1440,9 +1442,6 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; -#endif -#if EXECUTOR_DEBUG - LOG(INFO) << "Run node " << nid; #endif Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling); } else { diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index e4bcdd323fc6..0eb9acad02be 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -43,6 +43,7 @@ class GraphExecutor : public Executor { friend class autograd::AutogradRuntime; using Executor::MonitorCallback; + GraphExecutor(); virtual ~GraphExecutor(); void Forward(bool is_train) override; void PartialForward(bool is_train, int step, int *step_left) override; @@ -221,6 +222,8 @@ class GraphExecutor : public Executor { bool prefer_bulk_execution_; // cached segment operator std::vector cached_seg_opr_; + // verbose logging + bool log_verbose_ = false; }; } // namespace exec diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index 490851feb177..b4634aa2f74b 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -138,21 +138,6 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs); } -inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, - const Context& ctx, - std::vector *in_attrs, - std::vector *out_attrs) { - // TODO(junwu): add ctx info into storage inference logic - CHECK_EQ(in_attrs->size(), static_cast(2)) << " in operator " << attrs.name; - CHECK_EQ(out_attrs->size(), static_cast(1)) << " in operator " << attrs.name; - auto &in = *in_attrs; - auto &out = *out_attrs; - CHECK_NE(in[1], kUndefinedStorage) << "rhs 
storage type must be known"; - if (in[0] == kUndefinedStorage) in[0] = in[1]; - if (out[0] == kUndefinedStorage) out[0] = in[1]; - return true; -} - // Transfer gradient and input to FGradient function struct ElemwiseGradUseIn { const char *op_name; diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 270aae592f14..3137b3dae6b4 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -349,6 +349,16 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) { attrs->parsed = std::move(param); } +/*! \brief Perform storage fallback to invoke fcompute. + * \param attrs attributes of the operator + * \param ctx operator context + * \param inputs inputs of fcompute + * \param req req of fcompute + * \param outputs outputs of fcompute + * \param fcompute + * \param fname name of the operator + * \param mutate_idx the indices of mutable inputs + */ template void FCompExFallback(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -356,15 +366,25 @@ void FCompExFallback(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs, FCompute fcompute, - const std::string& fname) { + const std::string& fname, + std::vector mutate_idx = {}) { using namespace mxnet::common; std::vector in_blobs, out_blobs; - std::vector temp_in_src, temp_in_dst, temp_out_src, temp_out_dst; - SetupDefaultBlobs(inputs, &in_blobs, &temp_in_src, &temp_in_dst); - SetupDefaultBlobs(outputs, &out_blobs, &temp_out_src, &temp_out_dst); - CastNonDefaultStorage(temp_in_src, temp_in_dst, ctx, true); + std::vector pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; + // mapping from index in input_blobs to index in pre_temp_dst + std::unordered_map in_temp_idx_map; + SetupDefaultBlobs(inputs, &in_blobs, &pre_temp_src, &pre_temp_dst, &in_temp_idx_map); + SetupDefaultBlobs(outputs, &out_blobs, &post_temp_dst, &post_temp_src); + for (const auto idx : mutate_idx) { + auto map_iter = in_temp_idx_map.find(idx); + if (map_iter != in_temp_idx_map.end()) { + post_temp_src.push_back(pre_temp_dst[map_iter->second]); + post_temp_dst.push_back(inputs[idx]); + } + } + CastNonDefaultStorage(pre_temp_src, pre_temp_dst, ctx, true); fcompute(attrs, ctx, in_blobs, req, out_blobs); - CastNonDefaultStorage(temp_out_dst, temp_out_src, ctx, true); + CastNonDefaultStorage(post_temp_src, post_temp_dst, ctx, true); } #define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param) \ diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 14673904798a..cd6457ca48d0 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -241,7 +241,7 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, } else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage) { NDArray out = outputs[0]; SGDUpdateRspDnsImpl(param, ctx, inputs[0], inputs[1].data(), req[0], &out); - } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) { + } else { FCompExFallback(attrs, ctx, inputs, req, outputs, SGDUpdate, "SGDUpdate"); } } @@ -592,6 +592,8 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, auto weight_stype = weight.storage_type(); auto grad_stype = grad.storage_type(); auto mom_stype = mom.storage_type(); + CHECK_EQ(weight_stype, mom_stype) << "Inconsistent storage type detected between mom.stype = " + << mom_stype << " and weight.stype = " << weight_stype; if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage && mom_stype == kRowSparseStorage) { NDArray out = outputs[0]; @@ -600,9 +602,10 @@ 
inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, mom_stype == kRowSparseStorage) { NDArray out = outputs[0]; SGDMomUpdateRspDnsImpl(param, ctx, weight, grad.data(), mom, req[0], &out); - } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage && - mom_stype == kDefaultStorage) { - FCompExFallback(attrs, ctx, inputs, req, outputs, SGDMomUpdate, "SGDMomUpdate"); + } else { + // inputs[2] is a mutable input + FCompExFallback(attrs, ctx, inputs, req, outputs, + SGDMomUpdate, "SGDMomUpdate", {2}); } } diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 980fd1956448..05bbe975962f 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -22,8 +22,8 @@ It updates the weights using:: weight = weight - learning_rate * gradient -If weights are stored with `row_sparse` storage, -update is applied only to rows whose gradient has non-zero entries. +If weight is stored with `row_sparse` storage type, +only the row slices whose indices appear in grad.indices are updated. )code" ADD_FILELINE) .set_num_inputs(2) @@ -56,8 +56,8 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. -If weights are stored with `row_sparse` storage, -only rows whose gradients contain non-zero entries are updated (for both weight and momentum). +If weights are stored with `row_sparse` storage type, +only the row slices whose indices appear in grad.indices are updated (for both weight and momentum). )code" ADD_FILELINE) .set_num_inputs(3) diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h index 2ad1957a4648..f329ab62b1cf 100644 --- a/src/operator/tensor/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -331,6 +331,8 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 1); CHECK_EQ(outputs.size(), 1); + if (req[0] == kNullOp) return; + CHECK_EQ(req[0], kWriteTo) << "CastStorageComputeEx expects req[0] == kWriteTo"; CastStorageComputeImpl(ctx, inputs[0], outputs[0]); } diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc index f32133171130..f6561061034c 100644 --- a/src/operator/tensor/cast_storage.cc +++ b/src/operator/tensor/cast_storage.cc @@ -11,7 +11,6 @@ namespace mxnet { namespace op { -// TODO(haibin) declare backward op for cast storage DMLC_REGISTER_PARAMETER(CastStorageParam); NNVM_REGISTER_OP(cast_storage) .describe(R"code(Casts tensor storage type to the new type. 
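A small sketch of the row_sparse SGD semantics documented above, with hypothetical shapes; only the row slices listed in grad.indices are written, every other row of the weight is left untouched:

>>> import mxnet as mx
>>> weight = mx.nd.ones((4, 2)).tostype('row_sparse')
>>> grad = mx.nd.row_sparse_array([[1., 1.]], [2], (4, 2))  # gradient only for row 2
>>> w = mx.nd.sgd_update(weight, grad, lr=0.1, out=weight)
>>> weight.asnumpy()                    # row 2 becomes 1 - 0.1 * 1 = 0.9
array([[ 1. ,  1. ],
       [ 1. ,  1. ],
       [ 0.9,  0.9],
       [ 1. ,  1. ]], dtype=float32)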
@@ -28,6 +27,7 @@ NNVM_REGISTER_OP(cast_storage) }) .set_attr("FCompute", IdentityCompute) .set_attr("FComputeEx", CastStorageComputeEx) +.set_attr("FGradient", ElemwiseGradUseNone{"_copy"}) .add_argument("data", "NDArray-or-Symbol", "The input.") .add_arguments(CastStorageParam::__FIELDS__()); diff --git a/src/operator/tensor/cast_storage.cu b/src/operator/tensor/cast_storage.cu index 79f369fb2054..47977a2eac1a 100644 --- a/src/operator/tensor/cast_storage.cu +++ b/src/operator/tensor/cast_storage.cu @@ -10,7 +10,7 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(cast_storage) -.set_attr("FCompute", IdentityCompute) +.set_attr("FCompute", IdentityCompute) .set_attr("FComputeEx", CastStorageComputeEx); } // namespace op diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 04079f5ef9f3..98f7c4f46728 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -133,6 +133,7 @@ void BinaryComputeRspRspImpl(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { + if (req[0] == kNullOp) return; CHECK(req[0] == kWriteTo) << "only kWriteTo is supported for rowsparse elemwise_add"; using namespace rowsparse; using namespace mshadow; @@ -202,10 +203,10 @@ void BinaryComputeRspRspImpl(const nnvm::NodeAttrs& attrs, indices_out[iter_out] = indices_r[iter_r]; Copy(out[iter_out++], data_r[iter_r++], s); } - auto new_ashape = output.aux_shape(rowsparse::kIdx); - CHECK_GT(new_ashape[0], num_common_rows); - new_ashape[0] -= num_common_rows; - output.set_aux_shape(rowsparse::kIdx, new_ashape); + auto new_sshape = TShape(output.aux_shape(rowsparse::kIdx)); + CHECK_GT(new_sshape[0], num_common_rows); + new_sshape[0] -= num_common_rows; + output.set_aux_shape(rowsparse::kIdx, new_sshape); }); }); } diff --git a/src/operator/tensor/elemwise_unary_op.cu b/src/operator/tensor/elemwise_unary_op.cu index 6da7ceff16ac..e1f3ebb5c156 100644 --- a/src/operator/tensor/elemwise_unary_op.cu +++ b/src/operator/tensor/elemwise_unary_op.cu @@ -22,7 +22,8 @@ NNVM_REGISTER_OP(_backward_sigmoid) // copy NNVM_REGISTER_OP(_copy) -.set_attr("FCompute", IdentityCompute); +.set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", IdentityComputeEx); NNVM_REGISTER_OP(_backward_copy) .set_attr("FCompute", IdentityCompute); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 63a4a3ddb795..cb52169b9d33 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -88,8 +88,8 @@ void IdentityComputeRspRspImpl(const nnvm::NodeAttrs& attrs, using namespace mshadow; using namespace mshadow::expr; using namespace rowsparse; - CHECK_NE(req, kNullOp) << "kNullOp in IdentityComputeEx not supported yet"; - CHECK_NE(req, kWriteInplace) << "kWriteInplace in IdentityComputeEx not supported yet"; + if (req == kNullOp) return; + CHECK_EQ(req, kWriteTo) << "kWriteTo is expected for IdentityComputeRspRspImpl"; if (!input.storage_initialized()) { FillZerosRspImpl(s, output); return; @@ -120,6 +120,7 @@ void IdentityComputeEx(const nnvm::NodeAttrs& attrs, const auto in_stype = inputs[0].storage_type(); const auto out_stype = outputs[0].storage_type(); mshadow::Stream *s = ctx.get_stream(); + if (req[0] == kNullOp) return; if (in_stype == out_stype) { if (in_stype == kDefaultStorage) { // dense ndarray IdentityCompute(attrs, ctx, {inputs[0].data()}, req, {outputs[0].data()}); @@ -128,6 +129,7 @@ void 
IdentityComputeEx(const nnvm::NodeAttrs& attrs, FillComputeZerosEx(attrs, ctx, inputs, req, outputs); return; } + CHECK_NE(req[0], kAddTo) << "kAddTo is not supported for IdentityComputeEx"; const size_t n = mxnet::num_aux_data(out_stype); outputs[0].CheckAndAlloc(inputs[0].aux_shapes()); IdentityCompute(attrs, ctx, {inputs[0].data()}, req, {outputs[0].data()}); @@ -142,6 +144,21 @@ void IdentityComputeEx(const nnvm::NodeAttrs& attrs, } } +inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + // TODO(junwu): add ctx info into storage inference logic + CHECK_EQ(in_attrs->size(), static_cast(2)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(1)) << " in operator " << attrs.name; + auto &in = *in_attrs; + auto &out = *out_attrs; + CHECK_NE(in[1], kUndefinedStorage) << "rhs storage type must be known"; + if (in[0] == kUndefinedStorage) STORAGE_TYPE_ASSIGN_CHECK(in, 0, in[1]); + if (out[0] == kUndefinedStorage) STORAGE_TYPE_ASSIGN_CHECK(out, 0, in[1]); + return true; +} + template void IdentityLikeRhsComputeEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -155,12 +172,12 @@ void IdentityLikeRhsComputeEx(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); const auto in_stype = inputs[0].storage_type(); const auto out_stype = outputs[0].storage_type(); - // row_sparse -> row_sparse - if (in_stype == kRowSparseStorage && out_stype == kRowSparseStorage) { - NDArray out = outputs[0]; - IdentityComputeRspRspImpl(attrs, s, inputs[0], req[0], &out); + if (in_stype == out_stype) { + std::vector in{inputs[0]}; + IdentityComputeEx(attrs, ctx, in, req, outputs); } else { - LOG(FATAL) << "Not implemented yet"; + LOG(FATAL) << "IdentityLikeRhsComputeEx not implemented for in_stype = " << in_stype + << " out_stype = " << out_stype; } } diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 65f4e1001c07..64f1327196ff 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -169,7 +169,6 @@ void FillZerosCsrImpl(mshadow::Stream *s, NDArray *dst) { dst->set_aux_shape(csr::kIdx, new_shape); } -// This operator never needs to fall back, since there's no input NDArray template void FillComputeZerosEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -180,8 +179,9 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; Stream *s = ctx.get_stream(); CHECK_EQ(outputs.size(), 1); - CHECK_EQ(inputs.size(), 0); auto stype = outputs[0].storage_type(); + if (req[0] == kNullOp) return; + CHECK_EQ(req[0], kWriteTo) << "kWriteTo is expected for FillComputeZerosEx"; if (stype == kRowSparseStorage) { NDArray nd(outputs[0]); FillZerosRspImpl(s, &nd); @@ -189,7 +189,8 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs, NDArray nd(outputs[0]); FillZerosCsrImpl(s, &nd); } else { - LOG(FATAL) << "storage type not implemented."; + // no fallback is required since the output doesn't depend on input + LOG(FATAL) << "storage type " << stype << " not implemented."; } } diff --git a/src/operator/tensor/sparse_retain-inl.h b/src/operator/tensor/sparse_retain-inl.h index 04c81ee881aa..fb1b2512bf72 100644 --- a/src/operator/tensor/sparse_retain-inl.h +++ b/src/operator/tensor/sparse_retain-inl.h @@ -238,6 +238,7 @@ void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); + if (req[sr::kOut] == kNullOp) return; 
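For reference, a usage sketch of the sparse_retain operator guarded by the kNullOp early return added just above (hypothetical values; the row-id array must use default storage, as the CHECK that follows enforces):

>>> import mxnet as mx
>>> csr = mx.nd.array([[1, 2], [3, 4], [5, 6]]).tostype('csr')
>>> idx = mx.nd.array([0, 2])           # rows to keep
>>> mx.nd.sparse_retain(csr, idx).asnumpy()
array([[ 1.,  2.],
       [ 0.,  0.],
       [ 5.,  6.]], dtype=float32)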
CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'"; CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage) << "sparse_retain operator only takes default NDArray as its index array"; diff --git a/src/operator/tensor/square_sum.cc b/src/operator/tensor/square_sum.cc index 14c11ce82419..ae01d0f71c4a 100644 --- a/src/operator/tensor/square_sum.cc +++ b/src/operator/tensor/square_sum.cc @@ -17,7 +17,7 @@ in the future. Example:: dns = mx.nd.array([[0, 0], [1, 2], [0, 0], [3, 4], [0, 0]]) - rsp = mx.nd.cast_storage(dns, stype='row_sparse') + rsp = dns.tostype('row_sparse') sum = mx.nd._internal._square_sum(rsp, axis=1) sum = [0, 5, 0, 25, 0] )code" ADD_FILELINE) diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py index f88b412b027c..df52fec87f8e 100644 --- a/tests/nightly/dist_sync_kvstore.py +++ b/tests/nightly/dist_sync_kvstore.py @@ -26,8 +26,8 @@ def init_kv(): kv.init(keys, [mx.nd.ones(shape)] * len(keys)) kv.init('99', mx.nd.ones(big_shape)) # init kv row_sparse keys - kv.init(rsp_keys, [mx.nd.ones(shape)._to_rsp()] * len(rsp_keys)) - kv.init('100', mx.nd.ones(big_shape)._to_rsp()) + kv.init(rsp_keys, [mx.nd.ones(shape).tostype('row_sparse')] * len(rsp_keys)) + kv.init('100', mx.nd.ones(big_shape).tostype('row_sparse')) # worker info my_rank = kv.rank nworker = kv.num_workers @@ -60,7 +60,7 @@ def check_row_sparse_keys(kv, my_rank, nworker): v[my_row] = my_rank + 1 # push for i in range(nrepeat): - kv.push('9', v._to_rsp()) + kv.push('9', v.tostype('row_sparse')) # select a random subset of rows this worker is interested in num_rows = shape[0] row_ids_np = np.random.randint(num_rows, size=num_rows) @@ -86,13 +86,13 @@ def check_row_sparse_keys_with_zeros(kv, my_rank, nworker): big_v = mx.nd.zeros(big_shape) # push for i in range(nrepeat): - kv.push('11', v._to_rsp()) - kv.push('100', big_v._to_rsp()) + kv.push('11', v.tostype('row_sparse')) + kv.push('100', big_v.tostype('row_sparse')) # pull a subset of rows this worker is interested in all_row_ids = np.arange(shape[0]) - val = mx.nd.ones(shape)._to_rsp() - big_val = mx.nd.ones(big_shape)._to_rsp() + val = mx.nd.ones(shape).tostype('row_sparse') + big_val = mx.nd.ones(big_shape).tostype('row_sparse') kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids, dtype='int64')) big_num_rows = shape[0] big_all_row_ids = np.arange(big_shape[0]) @@ -125,7 +125,7 @@ def check_big_row_sparse_keys(kv, my_rank, nworker): v[row] = my_rank + 1 # push for i in range(nrepeat): - kv.push('100', v._to_rsp()) + kv.push('100', v.tostype('row_sparse')) # select a random subset of rows this worker is interested in mx.random.seed(my_rank) diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index 17519a6d1b06..974ee9571599 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py @@ -102,8 +102,7 @@ def check_unary_func(x): uniform = nd.uniform(shape=(4, 5)) stypes = ['row_sparse', 'csr', 'default'] for stype in stypes: - x = mx.nd.cast_storage(uniform, stype=stype) - check_unary_func(x) + check_unary_func(uniform.tostype(stype)) def test_binary_func(): def check_binary_func(x, y): @@ -121,8 +120,8 @@ def check_binary_func(x, y): stypes = ['row_sparse', 'csr', 'default'] for stype_x in stypes: for stype_y in stypes: - x = mx.nd.cast_storage(uniform_x, stype=stype_x) - y = mx.nd.cast_storage(uniform_y, stype=stype_y) + x = uniform_x.tostype(stype_x) + y = uniform_y.tostype(stype_y) check_binary_func(x, y) @@ 
-262,7 +261,7 @@ def check_attach_grad(x): zeros = mx.nd.zeros((10, 10)) stypes = ['default', 'row_sparse', 'csr'] for stype in stypes: - x = mx.nd.cast_storage(zeros, stype=stype) + x = zeros.tostype(stype) check_attach_grad(x) diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index 53a26d03e12b..ed09d59d4c15 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -43,7 +43,7 @@ def check_single_kv_pair(kv, key): def test_row_sparse_pull(): kv = init_kv_with_str('row_sparse') - kv.init('e', mx.nd.ones(shape)._to_rsp()) + kv.init('e', mx.nd.ones(shape).tostype('row_sparse')) def check_row_sparse_pull(kv, count): num_rows = shape[0] @@ -51,7 +51,7 @@ def check_row_sparse_pull(kv, count): row_ids = [] all_row_ids = np.arange(num_rows) for i in range(count): - vals.append(mx.nd.zeros(shape)._to_rsp()) + vals.append(mx.nd.zeros(shape).tostype('row_sparse')) row_id = np.random.randint(num_rows, size=num_rows) row_ids.append(mx.nd.array(row_id, dtype='int64')) row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 92ce08d3acda..98d6d69a0aac 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -487,8 +487,8 @@ def fm(factor_size, feature_dim, init): import scipy.sparse as sp # generate some random scipy csr data csr_sp = sp.rand(num_samples, feature_dim, density=0.1, format='csr') - csr_nd = mx.nd.csr(csr_sp.data, csr_sp.indptr, csr_sp.indices, - (num_samples, feature_dim)) + csr_nd = mx.nd.csr_matrix(csr_sp.data, csr_sp.indptr, csr_sp.indices, + (num_samples, feature_dim)) label = mx.nd.ones((num_samples,1)) # the alternative is to use LibSVMIter train_iter = mx.io.NDArrayIter(data=csr_nd, diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index b6c902b13370..553bf39ba028 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -34,17 +34,17 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa if w_stype == 'default': w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) w1 = w2.copyto(default_context()) - elif w_stype == 'row_sparse': - w2 = rand_ndarray(shape, w_stype, density=1) - w1 = w2.copyto(default_context()).todense() + elif w_stype == 'row_sparse' or w_stype == 'csr': + w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype) + w1 = w2.copyto(default_context()).tostype('default') else: raise Exception("type not supported yet") if g_stype == 'default': g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) g1 = g2.copyto(default_context()) - elif g_stype == 'row_sparse': - g2 = rand_ndarray(shape, g_stype) - g1 = g2.copyto(default_context()).todense() + elif g_stype == 'row_sparse' or g_stype == 'csr': + g2 = rand_ndarray(shape, g_stype, dtype=dtype) + g1 = g2.copyto(default_context()).tostype('default') else: raise Exception("type not supported yet") @@ -186,6 +186,13 @@ def test_sgd(): not kwarg['multi_precision'])): continue compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) + # test operator fallback on cpu + if (default_context() == mx.cpu()): + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, + g_stype='row_sparse') + if dtype != np.float16: + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2], + dtype, w_stype='csr', g_stype='csr') class 
PySparseSGD(mx.optimizer.Optimizer): """python reference implemenation of sgd""" @@ -260,24 +267,28 @@ def test_sparse_sgd(): mx.random.seed(0) opt1 = PySparseSGD opt2 = mx.optimizer.SGD - shape = (3, 4) - kwargs = [{}, - {'momentum': 0.9}, - {'clip_gradient': 0.5}, - {'clip_gradient': 0.4, 'rescale_grad': 0.14}, - {'rescale_grad': 0.8}, - {'clip_gradient': 0.5, 'wd': 0.07}, - {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03}, - {'rescale_grad': 0.8, 'wd': 0.05}, - {'clip_gradient': 0.5, 'momentum': 0.9}, - {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'momentum': 0.9}, - {'rescale_grad': 0.8, 'momentum': 0.9}, - {'clip_gradient': 0.5, 'wd': 0.07, 'momentum': 0.9}, - {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03, 'momentum': 0.9}, - {'rescale_grad': 0.8, 'wd': 0.05, 'momentum': 0.9}] - for kwarg in kwargs: - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, 'float32', w_stype='row_sparse', g_stype='row_sparse') - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, 'float32', w_stype='row_sparse', g_stype='default') + shape = (3, 4, 5) + mom_options = [{}, {'momentum': 0.9}] + cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}] + rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] + wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] + mp_options = [{}] + for dtype in [np.float32]: + for mom_option in mom_options: + for cg_option in cg_options: + for rg_option in rg_options: + for wd_option in wd_options: + for mp_option in mp_options: + kwarg = {} + kwarg.update(mom_option) + kwarg.update(cg_option) + kwarg.update(rg_option) + kwarg.update(wd_option) + kwarg.update(mp_option) + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, + w_stype='row_sparse', g_stype='row_sparse') + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, + w_stype='row_sparse', g_stype='default') # ADAM @@ -358,7 +369,10 @@ def test_adam(): {'rescale_grad': 0.8, 'wd': 0.05}] for kwarg in kwargs: compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32) - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32, g_stype='row_sparse') + # test operator fallback on cpu + if (default_context() == mx.cpu()): + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, + np.float32, g_stype='row_sparse') # RMSProp class PyRMSProp(mx.optimizer.Optimizer): diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 06ab437df226..ca38129a542b 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -15,7 +15,7 @@ def assert_fcompex(f, *args, **kwargs): def sparse_nd_ones(shape, stype): - return mx.nd.cast_storage(mx.nd.ones(shape), stype=stype) + return mx.nd.ones(shape).tostype(stype) def check_sparse_nd_elemwise_binary(shapes, stypes, f, g): @@ -192,7 +192,7 @@ def test_sparse_nd_lesser_equal(): def test_sparse_nd_binary(): N = 10 - def check_binary(fn): + def check_binary(fn, stype): for _ in range(N): ndim = 2 oshape = np.random.randint(1, 6, size=(ndim,)) @@ -207,54 +207,47 @@ def check_binary(fn): rshape[bdim-i-1] = 1 lhs = np.random.uniform(0, 1, size=lshape) rhs = np.random.uniform(0, 1, size=rshape) - lhs_nd_csr = mx.nd.array(lhs)._to_csr() - rhs_nd_csr = mx.nd.array(rhs)._to_csr() - lhs_nd_rsp = mx.nd.array(lhs)._to_rsp() - rhs_nd_rsp = mx.nd.array(rhs)._to_rsp() - for lhs_nd, rhs_nd in [(lhs_nd_csr, rhs_nd_csr), (lhs_nd_rsp, rhs_nd_rsp)]: - assert_allclose(fn(lhs, rhs), - fn(lhs_nd, rhs_nd).asnumpy(), - rtol=1e-4, 
atol=1e-4) - - check_binary(lambda x, y: x + y) - check_binary(lambda x, y: x - y) - check_binary(lambda x, y: x * y) - check_binary(lambda x, y: x / y) - check_binary(lambda x, y: x ** y) - check_binary(lambda x, y: x > y) - check_binary(lambda x, y: x < y) - check_binary(lambda x, y: x >= y) - check_binary(lambda x, y: x <= y) - check_binary(lambda x, y: x == y) + lhs_nd = mx.nd.array(lhs).tostype(stype) + rhs_nd = mx.nd.array(rhs).tostype(stype) + assert_allclose(fn(lhs, rhs), fn(lhs_nd, rhs_nd).asnumpy(), rtol=1e-4, atol=1e-4) + + stypes = ['row_sparse', 'csr'] + for stype in stypes: + check_binary(lambda x, y: x + y, stype) + check_binary(lambda x, y: x - y, stype) + check_binary(lambda x, y: x * y, stype) + check_binary(lambda x, y: x / y, stype) + check_binary(lambda x, y: x ** y, stype) + check_binary(lambda x, y: x > y, stype) + check_binary(lambda x, y: x < y, stype) + check_binary(lambda x, y: x >= y, stype) + check_binary(lambda x, y: x <= y, stype) + check_binary(lambda x, y: x == y, stype) def test_sparse_nd_binary_rop(): N = 10 - def check(fn): + def check(fn, stype): for _ in range(N): ndim = 2 shape = np.random.randint(1, 6, size=(ndim,)) - npy_nd = np.random.normal(0, 1, size=shape) - csr_nd = mx.nd.array(npy_nd)._to_csr() - rsp_nd = mx.nd.array(npy_nd)._to_rsp() - for sparse_nd in [csr_nd, rsp_nd]: - assert_allclose( - fn(npy_nd), - fn(sparse_nd).asnumpy(), - rtol=1e-4, - atol=1e-4 - ) - check(lambda x: 1 + x) - check(lambda x: 1 - x) - check(lambda x: 1 * x) - check(lambda x: 1 / x) - check(lambda x: 2 ** x) - check(lambda x: 1 > x) - check(lambda x: 0.5 > x) - check(lambda x: 0.5 < x) - check(lambda x: 0.5 >= x) - check(lambda x: 0.5 <= x) - check(lambda x: 0.5 == x) + npy = np.random.normal(0, 1, size=shape) + nd = mx.nd.array(npy).tostype(stype) + assert_allclose(fn(npy), fn(nd).asnumpy(), rtol=1e-4, atol=1e-4) + + stypes = ['row_sparse', 'csr'] + for stype in stypes: + check(lambda x: 1 + x, stype) + check(lambda x: 1 - x, stype) + check(lambda x: 1 * x, stype) + check(lambda x: 1 / x, stype) + check(lambda x: 2 ** x, stype) + check(lambda x: 1 > x, stype) + check(lambda x: 0.5 > x, stype) + check(lambda x: 0.5 < x, stype) + check(lambda x: 0.5 >= x, stype) + check(lambda x: 0.5 <= x, stype) + check(lambda x: 0.5 == x, stype) def test_sparse_nd_binary_iop(): N = 10 @@ -266,8 +259,8 @@ def check_binary(fn, stype): rshape = list(oshape) lhs = np.random.uniform(0, 1, size=lshape) rhs = np.random.uniform(0, 1, size=rshape) - lhs_nd = mx.nd.cast_storage(mx.nd.array(lhs), stype=stype) - rhs_nd = mx.nd.cast_storage(mx.nd.array(rhs), stype=stype) + lhs_nd = mx.nd.array(lhs).tostype(stype) + rhs_nd = mx.nd.array(rhs).tostype(stype) assert_allclose(fn(lhs, rhs), fn(lhs_nd, rhs_nd).asnumpy(), rtol=1e-4, atol=1e-4) @@ -285,10 +278,9 @@ def inplace_mul(x, y): check_binary(fn, stype) def test_sparse_nd_negate(): - npy = np.random.uniform(-10, 10, rand_shape_2d()) - arr_csr = mx.nd.array(npy)._to_csr() - arr_rsp = mx.nd.array(npy)._to_rsp() - for arr in [arr_csr, arr_rsp]: + def check_sparse_nd_negate(shape, stype): + npy = np.random.uniform(-10, 10, rand_shape_2d()) + arr = mx.nd.array(npy).tostype(stype) assert_almost_equal(npy, arr.asnumpy()) assert_almost_equal(-npy, (-arr).asnumpy()) @@ -297,6 +289,11 @@ def test_sparse_nd_negate(): # we compute (-arr) assert_almost_equal(npy, arr.asnumpy()) + shape = rand_shape_2d() + stypes = ['csr', 'row_sparse'] + for stype in stypes: + check_sparse_nd_negate(shape, stype) + def test_sparse_nd_broadcast(): sample_num = 1000 # TODO(haibin) 
@@ -312,7 +309,7 @@ def test_broadcast_to(stype):
             shape[axis] = 1
             dat = np.random.rand(*shape) - 0.5
             numpy_ret = dat
-            ndarray = mx.nd.cast_storage(mx.nd.array(dat), stype=stype)
+            ndarray = mx.nd.array(dat).tostype(stype)
             ndarray_ret = ndarray.broadcast_to(shape=target_shape)
             if type(ndarray_ret) is mx.ndarray.NDArray:
                 ndarray_ret = ndarray_ret.asnumpy()
@@ -328,7 +325,7 @@ def test_sparse_nd_transpose():
     npy = np.random.uniform(-10, 10, rand_shape_2d())
     stypes = ['csr', 'row_sparse']
     for stype in stypes:
-        nd = mx.nd.cast_storage(mx.nd.array(npy), stype=stype)
+        nd = mx.nd.array(npy).tostype(stype)
         assert_almost_equal(npy.T, (nd.T).asnumpy())
 
 def test_sparse_nd_output_fallback():
@@ -421,7 +418,7 @@ def test_create_csr():
     data = matrix.data
     indptr = matrix.indptr
     indices = matrix.indices
-    csr_created = mx.nd.csr(data=data, indptr=indptr, indices=indices, shape=shape)
+    csr_created = mx.nd.csr_matrix(data=data, indptr=indptr, indices=indices, shape=shape)
     assert csr_created.stype == 'csr'
     assert same(csr_created.data.asnumpy(), data.asnumpy())
     assert same(csr_created.indptr.asnumpy(), indptr.asnumpy())
@@ -439,7 +436,7 @@ def test_create_row_sparse():
     matrix = rand_ndarray(shape, 'row_sparse', density)
     data = matrix.data
     indices = matrix.indices
-    rsp_created = mx.nd.row_sparse(data=data, indices=indices, shape=shape)
+    rsp_created = mx.nd.row_sparse_array(data=data, indices=indices, shape=shape)
     assert rsp_created.stype == 'row_sparse'
     assert same(rsp_created.data.asnumpy(), data.asnumpy())
     assert same(rsp_created.indices.asnumpy(), indices.asnumpy())
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index dea5b99f05b0..680255cce4ff 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -45,8 +45,8 @@ def test_elemwise_add_ex_multiple_stages():
     val2 = mx.nd.array([[5, 10]]);
     idx1 = mx.nd.array([0], dtype=np.int64);
     idx2 = mx.nd.array([1], dtype=np.int64);
-    sp_nd1 = mx.nd.row_sparse(val1, idx1, shape)
-    sp_nd2 = mx.nd.row_sparse(val2, idx2, shape)
+    sp_nd1 = mx.nd.row_sparse_array(val1, idx1, shape)
+    sp_nd2 = mx.nd.row_sparse_array(val2, idx2, shape)
     ds_nd = mx.nd.array(ds_np)
 
     # sparse + sparse = sparse
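The hunk above only renames the constructor; for orientation, a minimal sketch of the sparse + sparse case it feeds, assuming `elemwise_add` dispatches to the sparse kernel as the test's comment indicates (the shape and `val1` here are illustrative, not taken from the test):

    import mxnet as mx
    import numpy as np

    shape = (3, 2)
    val1 = mx.nd.array([[1, 2]])
    val2 = mx.nd.array([[5, 10]])
    idx1 = mx.nd.array([0], dtype=np.int64)
    idx2 = mx.nd.array([1], dtype=np.int64)
    # each array stores one explicit row; the remaining rows are implicit zeros
    sp_nd1 = mx.nd.row_sparse_array(val1, idx1, shape)
    sp_nd2 = mx.nd.row_sparse_array(val2, idx2, shape)
    out = mx.nd.elemwise_add(sp_nd1, sp_nd2)
    assert out.stype == 'row_sparse'      # sparse + sparse stays sparse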
@@ -67,59 +67,60 @@ def test_elemwise_add_ex_multiple_stages():
     exec_test.backward(out_grads=exec_test.outputs)
     assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy())
 
-# TODO(haibin) also add test for backward pass.
 def test_cast_storage_ex():
-    def test_rsp_to_dns(shape, density):
-        rsp_in, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse', density)
-        dns_out = mx.nd.cast_storage(rsp_in, stype='default')
-        assert same(rsp_in.asnumpy(), dns_out.asnumpy())
-
-    def test_dns_to_rsp(shape, density):
-        rsp_in, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse', density)
-        rsp_out = mx.nd.cast_storage(mx.nd.array(rsp_in.todense(), dtype=default_dtype()), stype='row_sparse')
-        assert same(rsp_in.asnumpy(), rsp_out.asnumpy())
-
-    def test_csr_to_dns(shape, density):
-        csr_in, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr', density)
-        dns_out = mx.nd.cast_storage(csr_in, stype='default')
-        assert same(csr_in.asnumpy(), dns_out.asnumpy())
-
-    def test_dns_to_csr(shape, density):
-        csr_in, (indptr, colidx, data) = rand_sparse_ndarray(shape, 'csr', density)
-        csr_out = mx.nd.cast_storage(mx.nd.array(csr_in.todense(), dtype=default_dtype()), stype='csr')
-        assert same(csr_in.asnumpy(), csr_out.asnumpy())
+    def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=True):
+        x = mx.symbol.Variable('x', stype=from_stype)
+        x_nd = rand_ndarray(shape, from_stype, density=density)
+        x_np = x_nd.asnumpy()
+        out_np = x_np
+        test = mx.symbol.cast_storage(x, stype=to_stype)
+        location = {'x': x_nd}
+        check_symbolic_forward(test, location, [out_np])
+        # consider disabling the numeric grad check for the gpu block kernel since the input is large
+        if check_numeric_grad:
+            check_numeric_gradient(test, location)
+        grad_stypes = {'x': to_stype}
+        check_symbolic_backward(test, location, [out_np], [out_np], grad_stypes=grad_stypes)
 
     density = [1.00, 0.50, 0.10, 0.05, 0.01]
     for d in density:
         shape_2d = rand_shape_2d()
         shape_3d = rand_shape_3d()
-        test_csr_to_dns(shape_2d, d)
-        test_dns_to_csr(shape_2d, d)
-        test_rsp_to_dns(shape_2d, d)
-        test_dns_to_rsp(shape_2d, d)
-        test_rsp_to_dns(shape_3d, d)
-        test_dns_to_rsp(shape_3d, d)
+        check_cast_storage(shape_2d, d, 'csr', 'default')
+        check_cast_storage(shape_2d, d, 'default', 'csr')
+        check_cast_storage(shape_2d, d, 'row_sparse', 'default')
+        check_cast_storage(shape_2d, d, 'default', 'row_sparse')
+        check_cast_storage(shape_3d, d, 'row_sparse', 'default')
+        check_cast_storage(shape_3d, d, 'default', 'row_sparse')
         for i in range(4, 6):
             shape = rand_shape_nd(i, 5)
-            test_dns_to_rsp(shape, d)
-            test_rsp_to_dns(shape, d)
+            check_cast_storage(shape, d, 'default', 'row_sparse')
+            check_cast_storage(shape, d, 'row_sparse', 'default')
 
         # Test specific gpu kernels
         if default_context().device_type is 'gpu':
-            test_dns_to_csr((rnd.randint(1, 10), rnd.randint(  1,   32)), d)  # test gpu thread kernel
-            test_dns_to_csr((rnd.randint(1, 10), rnd.randint( 32,  512)), d)  # test gpu warp kernel
-            test_dns_to_csr((rnd.randint(1, 10), rnd.randint(512, 1024)), d)  # test gpu block kernel
-            test_dns_to_rsp((rnd.randint(1, 10), rnd.randint(  1,   32)), d)  # test gpu thread kernel
-            test_dns_to_rsp((rnd.randint(1, 10), rnd.randint( 32,  512)), d)  # test gpu warp kernel
-            test_dns_to_rsp((rnd.randint(1, 10), rnd.randint(512, 1024)), d)  # test gpu block kernel
-
+            dim0 = rnd.randint(1, 10)
+            # test gpu thread kernel
+            check_cast_storage((dim0, rnd.randint(  1,   32)), d, 'default', 'csr')
+            # test gpu warp kernel
+            check_cast_storage((dim0, rnd.randint( 32,  512)), d, 'default', 'csr')
+            # test gpu block kernel
+            check_cast_storage((dim0, rnd.randint(512, 1024)), d, 'default', 'csr',
+                               check_numeric_grad=False)
+            # test gpu thread kernel
+            check_cast_storage((dim0, rnd.randint(  1,   32)), d, 'default', 'row_sparse')
+            # test gpu warp kernel
+            check_cast_storage((dim0, rnd.randint( 32,  512)), d, 'default', 'row_sparse')
+            # test gpu block kernel
+            check_cast_storage((dim0, rnd.randint(512, 1024)), d, 'default', 'row_sparse',
+                               check_numeric_grad=False)
 
 def test_sparse_dot():
     def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
         lhs_nd = rand_ndarray(lhs_shape, 'csr', 1)
-        lhs_dns = lhs_nd.todense()
+        lhs_dns = lhs_nd.tostype('default')
         rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=density)
-        rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
+        rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.tostype('default')
         out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
         if trans_lhs and default_context().device_type is 'cpu':
             assert out.stype == 'row_sparse'
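A usage sketch of the sparse dot path that `test_dot_csr` exercises above: a csr lhs against a dense rhs, where `transpose_a=True` on CPU yields a row_sparse output per the assertion in the hunk (shapes and density are illustrative; `rand_ndarray` is the helper from `mxnet.test_utils` already used by these tests):

    import mxnet as mx
    from mxnet.test_utils import rand_ndarray

    lhs = rand_ndarray((5, 4), 'csr', density=0.5)   # csr lhs of shape (m, k)
    rhs = mx.nd.random_uniform(shape=(5, 3))         # dense rhs of shape (m, n)
    out = mx.nd.dot(lhs, rhs, transpose_a=True)      # computes lhs.T dot rhs
    assert out.shape == (4, 3)
    assert out.stype == 'row_sparse'                 # on cpu, as the test asserts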
@@ -221,7 +222,7 @@ def test_sparse_square_sum():
     for density in densities:
         shape = rand_shape_2d(dim0, dim1)
         rsp = rand_ndarray(shape, 'row_sparse', density)
-        dns = rsp.todense()
+        dns = rsp.tostype('default')
         for axis in axes:
             for keepdim in keepdims:
                 ret = mx.nd._internal._square_sum(rsp, axis=axis, keepdims=keepdim)
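The `_square_sum` check above compares a sparse-aware reduction against its dense equivalent. A condensed sketch of that comparison, assuming `_square_sum(x, axis, keepdims)` computes the sum of squared elements along `axis` and that `axis=1` is among the reductions the test iterates over (shape and density are illustrative):

    import mxnet as mx
    import numpy as np
    from mxnet.test_utils import rand_ndarray

    rsp = rand_ndarray((6, 4), 'row_sparse', 0.3)
    dns = rsp.tostype('default')
    ret = mx.nd._internal._square_sum(rsp, axis=1, keepdims=False)
    expected = (dns.asnumpy() ** 2).sum(axis=1)   # dense reference result
    assert np.allclose(ret.asnumpy(), expected)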