Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Refactor sparse tensor code #6955

Merged
merged 5 commits into from
Jul 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 81 additions & 59 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,17 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
NDArrayHandle *out);
const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
NDArrayHandle *out);

/*!
* \brief create a NDArray handle that is loaded from raw bytes.
Expand Down Expand Up @@ -406,7 +406,7 @@ MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
* \brief get the storage type of the array
*/
MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle,
int *out_storage_type);
int *out_storage_type);

/*!
* \brief Reshape the NDArray.
Expand Down Expand Up @@ -605,6 +605,28 @@ MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator,
int num_params,
const char **param_keys,
const char **param_vals);
/*!
 * \brief Invoke an nnvm op / imperative function, additionally reporting the
 *        storage types of the output NDArrays. Extended variant of
 *        MXImperativeInvoke: the extra \p out_stypes lets language bindings
 *        avoid a separate per-handle C API call to query each output's
 *        storage type (see the Python ctypes caller in this change set).
 * \param creator the op (AtomicSymbolCreator handle)
 * \param num_inputs number of input NDArrays
 * \param inputs input NDArrays
 * \param num_outputs number of output NDArrays
 * \param outputs output NDArrays
 * \param num_params number of keyword parameters
 * \param param_keys keys for keyword parameters
 * \param param_vals values for keyword parameters
 * \param out_stypes output ndarrays' storage types, one int per output
 *        (NOTE(review): buffer lifetime/ownership not shown here — presumably
 *        owned by the library like other out-parameters; confirm in c_api.cc)
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator,
int num_inputs,
NDArrayHandle *inputs,
int *num_outputs,
NDArrayHandle **outputs,
int num_params,
const char **param_keys,
const char **param_vals,
const int **out_stypes);
/*!
* \brief set whether to record operator for autograd
* \param is_train 1 when training, 0 when testing
Expand Down Expand Up @@ -1019,20 +1041,20 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
int *complete);
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
int *complete);

/*!
* \brief infer type of unknown input types given the known one.
Expand Down Expand Up @@ -1215,39 +1237,39 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
ExecutorHandle *out);

MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
/*!
* \brief set a call back to notify the completion of operation
*/
Expand Down
5 changes: 0 additions & 5 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@

namespace mxnet {

namespace ndarray {
template<typename from_xpu, typename to_xpu>
void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx);
};

namespace autograd {
class AGNode;

Expand Down
8 changes: 4 additions & 4 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
* Dispatched only when operators process non-default storage inputs or outputs
*/
using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs)>;
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs)>;

using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
const Context& ctx,
Expand Down
7 changes: 1 addition & 6 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from . import base
from . import contrib
from . import ndarray
from . import sparse_ndarray
from . import ndarray_utils
from . import ndarray as nd
from . import name
# use mx.sym as short for symbol
from . import symbol as sym
Expand All @@ -18,10 +17,6 @@
from . import io
from . import recordio
from . import operator
# use mx.nd as short for mx.ndarray
from . import ndarray as nd
from . import sparse_ndarray as sparse_nd
from . import ndarray_utils as nd_utils
# use mx.rnd as short for mx.random
from . import random as rnd
from . import random
Expand Down
24 changes: 20 additions & 4 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@
from .common import CachedOp


_STORAGE_TYPE_ID_TO_STR = {
-1 : 'undefined',
0 : 'default',
1 : 'row_sparse',
2 : 'csr',
}


class NDArrayBase(object):
"""Base data structure for ndarray"""
__slots__ = ["handle", "writable"]
# pylint: disable= no-member

def __init__(self, handle, writable=True):
"""initialize a new NDArray

Expand Down Expand Up @@ -62,22 +71,29 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
output_vars = ctypes.POINTER(NDArrayHandle)()
num_output = ctypes.c_int(0)

check_call(_LIB.MXImperativeInvoke(
# return output stypes to avoid the c_api call for checking
# a handle's stype in _ndarray_cls
out_stypes = ctypes.POINTER(ctypes.c_int)()

check_call(_LIB.MXImperativeInvokeEx(
ctypes.c_void_p(handle),
ctypes.c_int(len(ndargs)),
c_array(NDArrayHandle, [arr.handle for arr in ndargs]),
ctypes.byref(num_output),
ctypes.byref(output_vars),
ctypes.c_int(len(keys)),
c_array(ctypes.c_char_p, [c_str(key) for key in keys]),
c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals])))
c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals]),
ctypes.byref(out_stypes)))

if original_output is not None:
return original_output
if num_output.value == 1:
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle))
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]])
else:
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle))
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]])
for i in range(num_output.value)]


Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .base import mx_uint, NDArrayHandle, ExecutorHandle
from .base import check_call, c_array, py_str
from .ndarray import NDArray
from .sparse_ndarray import _ndarray_cls
from .ndarray import _ndarray_cls
from . import ndarray as nd

# those functions are not used here, we just import them to keep backward compatibility
Expand Down
6 changes: 3 additions & 3 deletions python/mxnet/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from .base import numeric_types
from . import ndarray as nd
from . import _ndarray_internal as _internal
from ._ndarray_internal import _cvimresize as imresize
from ._ndarray_internal import _cvcopyMakeBorder as copyMakeBorder
from .ndarray import _internal
from .ndarray._internal import _cvimresize as imresize
from .ndarray._internal import _cvcopyMakeBorder as copyMakeBorder
from . import io
from . import recordio

Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .base import mx_real_t
from .base import check_call, build_param_doc as _build_param_doc
from .ndarray import NDArray
from .sparse_ndarray import _ndarray_cls
from .ndarray import _ndarray_cls
from .ndarray import array
from .ndarray import concatenate

Expand Down
7 changes: 3 additions & 4 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
import logging
import warnings

import mxnet as mx
from .. import context as ctx
from .. import ndarray as nd
from .. import optimizer as opt

from .executor_group import DataParallelExecutorGroup
from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
from ..model import load_checkpoint
from ..initializer import Uniform, InitDesc
from ..ndarray import zeros

from .base_module import BaseModule, _check_input_names, _parse_data_desc

Expand Down Expand Up @@ -399,13 +398,13 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
else:
assert self._arg_params is None and self._aux_params is None
param_arrays = [
mx.nd.zeros(shape=x[0].shape, dtype=x[0].dtype, storage_type=x[0].storage_type)
zeros(shape=x[0].shape, dtype=x[0].dtype, stype=x[0].stype)
for x in self._exec_group.param_arrays
]
self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)}

aux_arrays = [
nd.zeros(x[0].shape, dtype=x[0].dtype)
zeros(x[0].shape, dtype=x[0].dtype)
for x in self._exec_group.aux_arrays
]
self._aux_params = {name:arr for name, arr in zip(self._aux_names, aux_arrays)}
Expand Down
12 changes: 12 additions & 0 deletions python/mxnet/ndarray/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""ndarray module"""

from . import _internal
from . import op
from .op import CachedOp, invoke
from .ndarray import NDArray, array, concatenate, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from .ndarray import empty, ones, add, arange, divide, equal, full, greater, greater_equal, imdecode
from .ndarray import lesser, lesser_equal, maximum, minimum, moveaxis, multiply, negative, not_equal
from .ndarray import onehot_encode, power, subtract, true_divide, waitall, _new_empty_handle
from .ndarray_utils import load, save, zeros
from .sparse_ndarray import _ndarray_cls
from .sparse_ndarray import csr, row_sparse, SparseNDArray, todense, RowSparseNDArray, CSRNDArray
Loading