code changes according to review comments
remove executor debug. add doc to optimizer

update sparse sgd test

add dtype option to rand_sparse_ndarray
eric-haibin-lin committed Aug 5, 2017
1 parent d511938 commit 57da011
Showing 18 changed files with 149 additions and 89 deletions.
24 changes: 24 additions & 0 deletions include/mxnet/c_api.h
@@ -719,6 +719,30 @@ MXNET_DLL int MXCreateCachedOp(SymbolHandle handle,
* \brief free cached operator
*/
MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle);
/*!
* \brief invoke cached operator
*/
MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle,
int num_inputs,
NDArrayHandle *inputs,
int *num_outputs,
NDArrayHandle **outputs);
/*!
* \brief invoke a cached op
* \param handle the handle to the cached op
* \param num_inputs number of input NDArrays
* \param inputs input NDArrays
* \param num_outputs number of output NDArrays
* \param outputs output NDArrays
* \param out_stypes output ndarrays' stypes
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
int num_inputs,
NDArrayHandle *inputs,
int *num_outputs,
NDArrayHandle **outputs,
const int** out_stypes);
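
For orientation, a minimal sketch of how a frontend might drive the new Ex entry point through ctypes. The helper name, the simplified c_void_p handle types, and the already-loaded `lib` are illustrative assumptions, not part of this header:

import ctypes

def invoke_cached_op_ex(lib, cached_op_handle, input_handles):
    # Hypothetical helper around MXInvokeCachedOpEx; `lib` is an
    # already-loaded libmxnet CDLL and handles are simplified to c_void_p.
    num_outputs = ctypes.c_int(0)
    output_vars = ctypes.POINTER(ctypes.c_void_p)()
    out_stypes = ctypes.POINTER(ctypes.c_int)()  # per-output storage-type codes
    inputs = (ctypes.c_void_p * len(input_handles))(*input_handles)
    ret = lib.MXInvokeCachedOpEx(
        cached_op_handle,
        ctypes.c_int(len(input_handles)),
        inputs,
        ctypes.byref(num_outputs),
        ctypes.byref(output_vars),
        ctypes.byref(out_stypes))
    assert ret == 0, "MXInvokeCachedOpEx failed"  # 0 on success, -1 on failure
    # out_stypes[i] mirrors outputs[i]'s storage type, saving a per-output
    # stype query through the C API.
    return [(output_vars[i], out_stypes[i]) for i in range(num_outputs.value)]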
/*!
* \brief invoke cached operator
*/
1 change: 0 additions & 1 deletion include/mxnet/ndarray.h
@@ -851,7 +851,6 @@ size_t num_aux_data(NDArrayStorageType stype);
* \param from the ndarray we want to copy data from
* \param to the target ndarray
* \param priority Priority of the action.
* \param alloc_output whether to allocate memory for the output ndarray
* \note The function name explicitly marks the order of from and to
* due to different possible convention carried by copy function.
*/
15 changes: 11 additions & 4 deletions python/mxnet/_ctypes/ndarray.py
@@ -127,17 +127,24 @@ def __call__(self, *args, **kwargs):
"CachedOp.__call__ got unexpected keyword argument(s): " + \
', '.join(kwargs.keys()))

check_call(_LIB.MXInvokeCachedOp(
# return output stypes to avoid the c_api call for checking
# a handle's stype in _ndarray_cls
out_stypes = ctypes.POINTER(ctypes.c_int)()

check_call(_LIB.MXInvokeCachedOpEx(
self.handle,
ctypes.c_int(len(args)),
c_array(NDArrayHandle, [arr.handle for arr in args]),
ctypes.byref(num_output),
ctypes.byref(output_vars)))
ctypes.byref(output_vars),
ctypes.byref(out_stypes)))

if original_output is not None:
return original_output
if num_output.value == 1:
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle))
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]])
else:
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle))
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]])
for i in range(num_output.value)]
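
A hedged end-to-end sketch of what this buys the imperative frontend: the output of a cached op on a sparse input now carries its storage type directly. The import path follows this file, but the symbol, shapes, and sparse dispatch here are illustrative assumptions:

>>> import mxnet as mx
>>> from mxnet._ctypes.ndarray import CachedOp
>>> x = mx.sym.Variable('x')
>>> op = CachedOp(x + x)                        # assumed constructor signature
>>> a = mx.nd.zeros((2, 2), stype='row_sparse')
>>> out = op(a)
>>> out.stype                                   # set from out_stypes, no extra C API call
'row_sparse'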
3 changes: 2 additions & 1 deletion python/mxnet/kvstore.py
@@ -255,11 +255,12 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
row_ids : NDArray or list of NDArray
The row_ids to pull for each value. The row_ids do not have to be unique
or sorted.
Examples
--------
>>> shape = (3, 3)
>>> kv.init('3', mx.nd.ones(shape)._to_rsp())
>>> a = mx.nd.zeros(shape)
>>> a = mx.nd.zeros(shape, stype='row_sparse')
>>> row_ids = mx.nd.array([0, 2], dtype='int64')
>>> kv.row_sparse_pull('3', out=a, row_ids=row_ids)
>>> print a.asnumpy()
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/__init__.py
@@ -9,4 +9,4 @@
from .ndarray import onehot_encode, power, subtract, true_divide, waitall, _new_empty_handle
from .ndarray_utils import load, save, zeros, empty, array
from .sparse_ndarray import _ndarray_cls, todense
from .sparse_ndarray import csr, row_sparse, BaseSparseNDArray, RowSparseNDArray, CSRNDArray
from .sparse_ndarray import csr_matrix, row_sparse_array, BaseSparseNDArray, RowSparseNDArray, CSRNDArray
56 changes: 35 additions & 21 deletions python/mxnet/ndarray/sparse_ndarray.py
@@ -115,7 +115,7 @@ def __setitem__(self, key, value):
Examples
--------
>>> src = mx.nd.row_sparse([[1, 0, 2], [4, 5, 6]], [0, 2], (3,3))
>>> src = mx.nd.row_sparse_array([[1, 0, 2], [4, 5, 6]], [0, 2], (3,3))
>>> src.asnumpy()
array([[ 1., 0., 2.],
[ 0., 0., 0.],
@@ -325,7 +325,9 @@ def _aux_data(self, i):

# pylint: disable=abstract-method
class CSRNDArray(BaseSparseNDArray):
"""A CSRNDArray represents a NDArray as three separate arrays: `data`,
"""A sparse representation of tensor in standard CSR format.
A CSRNDArray represents an NDArray as three separate arrays: `data`,
`indptr` and `indices`. It uses the standard CSR representation, where the column indices for
row i are stored in ``indices[indptr[i]:indptr[i+1]]`` and their corresponding values are stored
in ``data[indptr[i]:indptr[i+1]]``.
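
To make the `indptr`/`indices` relation concrete, a small illustrative check (values borrowed from the `csr_matrix` example later in this diff; the int64 output dtype reflects the new index defaults):

>>> import mxnet as mx
>>> a = mx.nd.csr_matrix([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3))
>>> indptr = a.indptr.asnumpy()
>>> a.indices.asnumpy()[indptr[1]:indptr[2]]   # column indices of row 1
array([0], dtype=int64)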
@@ -391,26 +393,34 @@ def indptr(self):

# pylint: disable=abstract-method
class RowSparseNDArray(BaseSparseNDArray):
"""A RowSparseNDArray is typically used to represent a subset of a larger
NDArray with `default` of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. The values
in indices are the indices in the first dimension of the slices that have been extracted from
the larger NDArray. The indices are expected to be sorted in ascending order.
"""A sparse representation of a set of tensor slices at given indices.
The corresponding NDArray ``dense`` with `default` storage represented by a ``rsp``
RowSparseNDArray
A RowSparseNDArray represents an K-dimensional NDArray as two separate arrays: `data` and
`indices`. The `indices` stores the indices of the K-dimensional `data` slices extracted
from the dense NDArray in the first dimension. The corresponding NDArray ``dense``
represented by RowSparseNDArray ``rsp`` has
``dense[rsp.indices[i], :, :, :, ...] = rsp.data[i, :, :, :, ...]``,
``dense[rsp.indices[i], :, :, :, ...] = rsp.values[i, :, :, :, ...]``
where `indices` is an 1-D integer NDArray with shape [D0], and `data` is an NDArray of any
dtype with shape [D0, D1, .., DK]. If the index of a slice in the first dimension
doesn't appear in `indices`, its values are zeros.
A RowSparseNDArray is typically used to represent a subset of a larger dense NDArray of
shape [LARGE0, D1, .. , DK] where LARGE0 >> D0 and most row slices are zeros.
The indices are expected to be sorted in ascending order.
RowSparseNDArray is used principally in the definition of gradients for operations
that have sparse gradients (e.g. dot with sparse inputs).
Examples
--------
>>> import mxnet as mx
>>> dense = mx.nd.array([[1,2],[0,0],[3,0],[0,0]])
>>> dense = mx.nd.array([[1,2],[0,0],[3,0],[0,0],[0,0],[0,0]])
>>> rsp = dense._to_rsp()
>>> rsp.indices.asnumpy()
array([0, 2], dtype=int32)
array([0, 2], dtype=int64)
>>> rsp.data.asnumpy()
array([[ 1., 2.],
[ 3., 0.]], dtype=float32)
@@ -452,6 +462,9 @@ def indices(self):


def _prepare_src_array(src, dtype, default_dtype):
"""Prepare `src` and its dtype so that they can be used to construct NDArray.
`src` is converted to a `np.ndarray` if it's neither an `NDArray` nor an `np.ndarray`.
"""
if isinstance(src, NDArray):
dtype = src.dtype if dtype is None else dtype
else:
@@ -464,8 +477,9 @@ def _prepare_src_array(src, dtype, default_dtype):
return src, dtype


def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None):
"""Creates a 2D array with compressed sparse row format.
def csr_matrix(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None,
indices_type=None):
"""Creates a 2D array with compressed sparse row(CSR) format.
Parameters
----------
@@ -484,10 +498,10 @@ def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, in
if `data` is an `NDArray`, `float32` otherwise.
indptr_type: str or numpy.dtype, optional
The data type of the indptr array. The default dtype is ``indptr.dtype``
if `indptr` is an `NDArray`, `int32` otherwise.
if `indptr` is an `NDArray`, `int64` otherwise.
indices_type: str or numpy.dtype, optional
The data type of the indices array. The default dtype is ``indices.dtype``
if `indices` is an `NDArray`, `int32` otherwise.
if `indices` is an `NDArray`, `int64` otherwise.
Returns
-------
@@ -497,7 +511,7 @@ def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, in
Example
-------
>>> import mxnet as mx
>>> a = mx.nd.csr([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3))
>>> a = mx.nd.csr_matrix([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3))
>>> a.asnumpy()
array([[ 0., 1., 0.],
[ 2., 0., 0.],
@@ -540,13 +554,13 @@ def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, in
return result


def row_sparse(data, indices, shape, ctx=None, dtype=None, indices_type=None):
"""Creates a row sparse array with a set of tensor slices at given indices.
def row_sparse_array(data, indices, shape, ctx=None, dtype=None, indices_type=None):
"""Creates a K-dimensional row sparse array with a set of tensor slices at given indices.
Parameters
----------
data: array_like
An object exposing the array interface, with shape [D0, D1, .. Dn], where D0 is
An object exposing the array interface, with shape [D0, D1, .. DK], where D0 is
the number of rows with non-zero entries.
indices: array_like
An object exposing the array interface, with shape [D0].
@@ -557,7 +571,7 @@ def row_sparse(data, indices, shape, ctx=None, dtype=None, indices_type=None):
if `data` is an `NDArray`, `float32` otherwise.
indices_type: str or numpy.dtype, optional
The data type of the indices array. The default dtype is ``indices.dtype``
if `indices` is an `NDArray`, `int32` otherwise.
if `indices` is an `NDArray`, `int64` otherwise.
Returns
-------
@@ -566,7 +580,7 @@
Example
-------
>>> a = mx.nd.row_sparse([[1, 2], [3, 4]], [1, 4], (6, 2))
>>> a = mx.nd.row_sparse_array([[1, 2], [3, 4]], [1, 4], (6, 2))
>>> a.asnumpy()
array([[ 0., 0.],
[ 1., 2.],
4 changes: 2 additions & 2 deletions python/mxnet/optimizer.py
@@ -316,8 +316,8 @@ class SGD(Optimizer):
state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
weight = weight - state
For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update` and
:class:`~mxnet.ndarray.sgd_mom_update`.
Sparse updating is supported. For details of the update algorithm see
:class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`.
This optimizer accepts the following parameters in addition to those accepted
by :class:`.Optimizer`.
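
As a hedged sketch of the sparse path (helper names come from test_utils as updated below; whether a given stype combination hits the sparse kernel or falls back to the dense update is decided by SGDUpdateEx in optimizer_op-inl.h):

>>> import mxnet as mx
>>> from mxnet.test_utils import rand_ndarray
>>> weight = rand_ndarray((4, 2), 'row_sparse', density=1.0)
>>> grad = rand_ndarray((4, 2), 'row_sparse', density=0.5)
>>> opt = mx.optimizer.SGD(learning_rate=0.1)
>>> opt.update(0, weight, grad, opt.create_state(0, weight))
>>> weight.stype                               # weight keeps its row_sparse storage
'row_sparse'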
16 changes: 8 additions & 8 deletions python/mxnet/test_utils.py
@@ -79,34 +79,34 @@ def random_sample(population, k):
return population_copy[0:k]


def rand_sparse_ndarray(shape, stype, density=None):
def rand_sparse_ndarray(shape, stype, density=None, dtype=None):
"""Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) """
density = rnd.rand() if density is None else density
if stype == 'row_sparse':
# sample index
idx_sample = rnd.rand(shape[0])
indices = np.argwhere(idx_sample < density).flatten()
if indices.shape[0] == 0:
result = mx.nd.zeros(shape, stype='row_sparse')
result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype)
return result, (np.array([], dtype='int64'), np.array([], dtype='int64'))
# generate random values
val = rnd.rand(indices.shape[0], *shape[1:])
arr = mx.nd.row_sparse(val, indices, shape, indices_type=np.int64)
val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype)
arr = mx.nd.row_sparse_array(val, indices, shape, indices_type=np.int64, dtype=dtype)
return arr, (val, indices)
elif stype == 'csr':
assert(len(shape) == 2)
csr = sp.rand(shape[0], shape[1], density=density, format='csr')
result = mx.nd.csr(csr.data, csr.indptr, csr.indices, shape)
result = mx.nd.csr_matrix(csr.data, csr.indptr, csr.indices, shape, dtype=dtype)
return result, (csr.indptr, csr.indices, csr.data)
else:
assert(False), "unknown storage type"


def rand_ndarray(shape, stype, density=None):
def rand_ndarray(shape, stype, density=None, dtype=None):
if stype == 'default':
arr = mx.nd.array(random_arrays(shape))
arr = mx.nd.array(random_arrays(shape), dtype=dtype)
else:
arr, _ = rand_sparse_ndarray(shape, stype, density=density)
arr, _ = rand_sparse_ndarray(shape, stype, density=density, dtype=dtype)
return arr
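
A hedged usage sketch of the new dtype plumbing (shape and density are arbitrary; at low densities the empty-indices fallback returns the same tuple layout):

>>> import numpy as np
>>> from mxnet.test_utils import rand_sparse_ndarray
>>> arr, (val, idx) = rand_sparse_ndarray((5, 3), 'row_sparse', density=0.5, dtype=np.float64)
>>> arr.stype, arr.dtype == np.float64
('row_sparse', True)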


18 changes: 18 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -643,6 +643,24 @@ int MXInvokeCachedOp(CachedOpHandle handle,
API_END();
}

int MXInvokeCachedOpEx(CachedOpHandle handle,
int num_inputs,
NDArrayHandle *inputs,
int *num_outputs,
NDArrayHandle **outputs,
const int **out_stypes) { // outputs storage types
API_BEGIN();
MXInvokeCachedOp(handle, num_inputs, inputs, num_outputs, outputs);
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
NDArray** output_nds = reinterpret_cast<NDArray**>(*outputs);
ret->out_types.resize(*num_outputs);
for (int i = 0; i < *num_outputs; ++i) {
ret->out_types[i] = output_nds[i]->storage_type();
}
*out_stypes = dmlc::BeginPtr(ret->out_types);
API_END();
}

int MXAutogradSetIsTraining(int is_training, int* prev) {
API_BEGIN();
*prev = AutogradRuntime::Get()->SetIsTraining(static_cast<bool>(is_training));
2 changes: 1 addition & 1 deletion src/common/utils.h
@@ -43,7 +43,7 @@ inline bool GetDefaultBlobs(const std::vector<NDArray>& src,
for (size_t i = 0; i < src.size(); i++) {
auto& nd = src[i];
if (nd.storage_type() != kDefaultStorage) {
NDArray temp(nd.shape(), nd.ctx(), false);
NDArray temp(nd.shape(), nd.ctx(), false, nd.dtype());
temp_src->emplace_back(nd);
temp_dst->emplace_back(temp);
blobs->emplace_back(temp.data());
6 changes: 0 additions & 6 deletions src/executor/graph_executor.cc
@@ -1413,9 +1413,6 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning;
#else
bool profiling = false;
#endif
#if EXECUTOR_DEBUG
LOG(INFO) << "Run node " << nid << " - " << seg_op.topo_end - 1;
#endif
Engine::Get()->Push(seg_op.opr, seg_op.ctx, 0, profiling);
nid = seg_op.topo_end - 1;
@@ -1440,9 +1437,6 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning;
#else
bool profiling = false;
#endif
#if EXECUTOR_DEBUG
LOG(INFO) << "Run node " << nid;
#endif
Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
} else {
15 changes: 0 additions & 15 deletions src/operator/elemwise_op_common.h
Original file line number Diff line number Diff line change
@@ -138,21 +138,6 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs,
attrs, in_attrs, out_attrs);
}

inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs,
const Context& ctx,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
// TODO(junwu): add ctx info into storage inference logic
CHECK_EQ(in_attrs->size(), static_cast<size_t>(2)) << " in operator " << attrs.name;
CHECK_EQ(out_attrs->size(), static_cast<size_t>(1)) << " in operator " << attrs.name;
auto &in = *in_attrs;
auto &out = *out_attrs;
CHECK_NE(in[1], kUndefinedStorage) << "rhs storage type must be known";
if (in[0] == kUndefinedStorage) in[0] = in[1];
if (out[0] == kUndefinedStorage) out[0] = in[1];
return true;
}

// Transfer gradient and input to FGradient function
struct ElemwiseGradUseIn {
const char *op_name;
5 changes: 2 additions & 3 deletions src/operator/optimizer_op-inl.h
@@ -241,7 +241,7 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
} else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage) {
NDArray out = outputs[0];
SGDUpdateRspDnsImpl<xpu>(param, ctx, inputs[0], inputs[1].data(), req[0], &out);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) {
} else {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDUpdate<xpu>, "SGDUpdate");
}
}
@@ -600,8 +600,7 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
mom_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDMomUpdateRspDnsImpl<xpu>(param, ctx, weight, grad.data(), mom, req[0], &out);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage &&
mom_stype == kDefaultStorage) {
} else {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDMomUpdate<xpu>, "SGDMomUpdate");
}
}