diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 0633e92d6b90..bd519fd7f417 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -266,17 +266,17 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, - const mx_uint *shape, - mx_uint ndim, - int dev_type, - int dev_id, - int delay_alloc, - int dtype, - mx_uint num_aux, - int *aux_type, - mx_uint *aux_ndims, - const mx_uint *aux_shape, - NDArrayHandle *out); + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out); /*! * \brief create a NDArray handle that is loaded from raw bytes. @@ -406,7 +406,7 @@ MXNET_DLL int MXNDArrayAt(NDArrayHandle handle, * \brief get the storage type of the array */ MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle, - int *out_storage_type); + int *out_storage_type); /*! * \brief Reshape the NDArray. @@ -604,8 +604,29 @@ MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator, NDArrayHandle **outputs, int num_params, const char **param_keys, - const char **param_vals, - const int** out_stypes); + const char **param_vals); +/*! + * \brief invoke a nnvm op and imperative function + * \param creator the op + * \param num_inputs number of input NDArrays + * \param inputs input NDArrays + * \param num_outputs number of output NDArrays + * \param outputs output NDArrays + * \param num_params number of keyword parameters + * \param param_keys keys for keyword parameters + * \param param_vals values for keyword parameters + * \param out_stypes output ndarrays' stypes + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs, + int num_params, + const char **param_keys, + const char **param_vals, + const int **out_stypes); /*! * \brief set whether to record operator for autograd * \param is_train 1 when training, 0 when testing @@ -1020,20 +1041,20 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const mx_uint *arg_ind_ptr, - const mx_uint *arg_shape_data, - mx_uint *in_shape_size, - const mx_uint **in_shape_ndim, - const mx_uint ***in_shape_data, - mx_uint *out_shape_size, - const mx_uint **out_shape_ndim, - const mx_uint ***out_shape_data, - mx_uint *aux_shape_size, - const mx_uint **aux_shape_ndim, - const mx_uint ***aux_shape_data, - int *complete); + mx_uint num_args, + const char** keys, + const mx_uint *arg_ind_ptr, + const mx_uint *arg_shape_data, + mx_uint *in_shape_size, + const mx_uint **in_shape_ndim, + const mx_uint ***in_shape_data, + mx_uint *out_shape_size, + const mx_uint **out_shape_ndim, + const mx_uint ***out_shape_data, + mx_uint *aux_shape_size, + const mx_uint **aux_shape_ndim, + const mx_uint ***aux_shape_data, + int *complete); /*! * \brief infer type of unknown input types given the known one. 
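The hunk above splits the imperative entry point: `MXImperativeInvoke` keeps its original signature, while the new `MXImperativeInvokeEx` additionally reports the output arrays' storage types through `out_stypes` (filled from the thread-local `out_types` in c_api_ndarray.cc further down). A rough ctypes sketch of how a frontend might drive the Ex variant; the `_zeros` op, its `shape` parameter string, and the exact marshalling details are illustrative assumptions, not part of this patch:

    import ctypes
    from mxnet.base import _LIB, check_call, c_str, c_array, NDArrayHandle, OpHandle

    # Look up the op handle the same way _init_ndarray_module does.
    op = OpHandle()
    check_call(_LIB.NNGetOpHandle(c_str('_zeros'), ctypes.byref(op)))

    num_out = ctypes.c_int(0)                      # 0 = no preallocated outputs
    out_handles = ctypes.POINTER(NDArrayHandle)()
    out_stypes = ctypes.POINTER(ctypes.c_int)()
    keys = c_array(ctypes.c_char_p, [c_str('shape')])
    vals = c_array(ctypes.c_char_p, [c_str('(2, 2)')])
    check_call(_LIB.MXImperativeInvokeEx(
        op, 0, None,                               # no input NDArrays
        ctypes.byref(num_out), ctypes.byref(out_handles),
        1, keys, vals,
        ctypes.byref(out_stypes)))                 # one storage-type id per output
    # out_stypes[i] indexes into _STORAGE_TYPE_ID_TO_STR, so the frontend can
    # wrap each returned handle in the matching NDArray subclass.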
@@ -1216,39 +1237,39 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle, ExecutorHandle *out); MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, - int dev_type, - int dev_id, - const mx_uint num_g2c_keys, - const char** g2c_keys, - const int* g2c_dev_types, - const int* g2c_dev_ids, - const mx_uint provided_grad_req_list_len, - const char** provided_grad_req_names, - const char** provided_grad_req_types, - const mx_uint num_provided_arg_shapes, - const char** provided_arg_shape_names, - const mx_uint* provided_arg_shape_data, - const mx_uint* provided_arg_shape_idx, - const mx_uint num_provided_arg_dtypes, - const char** provided_arg_dtype_names, - const int* provided_arg_dtypes, - const mx_uint num_provided_arg_stypes, - const char** provided_arg_stype_names, - const int* provided_arg_stypes, - const mx_uint num_shared_arg_names, - const char** shared_arg_name_list, - int* shared_buffer_len, - const char** shared_buffer_name_list, - NDArrayHandle* shared_buffer_handle_list, - const char*** updated_shared_buffer_name_list, - NDArrayHandle** updated_shared_buffer_handle_list, - mx_uint* num_in_args, - NDArrayHandle** in_args, - NDArrayHandle** arg_grads, - mx_uint* num_aux_states, - NDArrayHandle** aux_states, - ExecutorHandle shared_exec_handle, - ExecutorHandle* out); + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const mx_uint* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out); /*! 
* \brief set a call back to notify the completion of operation */ diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index e1e6269e3d8b..662b45546cb4 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -29,11 +29,6 @@ namespace mxnet { -namespace ndarray { -template -void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx); -}; - namespace autograd { class AGNode; diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 6de6e6bf479c..eaa87374c8d7 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -67,10 +67,10 @@ using FCompute = std::function& inputs, - const std::vector& req, - const std::vector& outputs)>; + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs)>; using FInferStorageType = std::function= (3, 0): - from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR - from ._cy3.ndarray import invoke, CachedOp, _imperative_invoke - else: - from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR - from ._cy2.ndarray import invoke, CachedOp, _imperative_invoke -except ImportError: - if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: - raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") - from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR - from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke -# pylint: enable=unused-import +from ..base import _LIB, numeric_types +from ..base import c_array, mx_real_t +from ..base import mx_uint, NDArrayHandle, check_call +from ..base import ctypes2buffer +from ..context import Context +from . import _internal +from .op import NDArrayBase, _STORAGE_TYPE_ID_TO_STR +from . import broadcast_add, broadcast_mul, transpose, broadcast_not_equal, broadcast_power +from . import broadcast_sub, broadcast_div, broadcast_to, broadcast_equal, cast_storage +from . import broadcast_greater, broadcast_greater_equal, broadcast_lesser, broadcast_lesser_equal + # pylint: disable= no-member _DTYPE_NP_TO_MX = { @@ -743,9 +722,7 @@ def dtype(self): @property def stype(self): - if self._stype is None: - self._stype = _storage_type(self.handle) - return self._stype + return _storage_type(self.handle) @property # pylint: disable= invalid-name, undefined-variable @@ -973,11 +950,11 @@ def backward(self, out_grad=None, retain_graph=False): def _to_csr(self): # pylint: disable=undefined-variable - return cast_storage(self, storage_type='csr') + return cast_storage(self, stype='csr') def _to_rsp(self): # pylint: disable=undefined-variable - return cast_storage(self, storage_type='row_sparse') + return cast_storage(self, stype='row_sparse') def onehot_encode(indices, out): """One-hot encoding indices into matrix out. 
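Two frontend changes ride along with the import reshuffle above: `stype` is now looked up from the handle on every access instead of being cached, and the conversion helpers pass the renamed `stype=` keyword to `cast_storage`. A short sketch of the resulting behaviour, assuming the rest of this patch is applied:

    import mxnet as mx

    dense = mx.nd.array([[0, 1, 0], [2, 0, 0]])
    dense.stype                  # 'default'

    csr = dense._to_csr()        # cast_storage(self, stype='csr')
    csr.stype                    # 'csr'

    rsp = dense._to_rsp()        # cast_storage(self, stype='row_sparse')
    rsp.stype                    # 'row_sparse'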
@@ -2196,160 +2173,38 @@ def imdecode(str_img, clip_rect=(0, 0, 0, 0), out=None, index=0, channels=3, mea out=out) -# pylint: disable=too-many-locals, invalid-name -def _make_ndarray_function(handle, name): - """Create a NDArray function from the FunctionHandle.""" - real_name = ctypes.c_char_p() - desc = ctypes.c_char_p() - num_args = mx_uint() - arg_names = ctypes.POINTER(ctypes.c_char_p)() - arg_types = ctypes.POINTER(ctypes.c_char_p)() - arg_descs = ctypes.POINTER(ctypes.c_char_p)() - key_var_num_args = ctypes.c_char_p() - ret_type = ctypes.c_char_p() - - check_call(_LIB.MXSymbolGetAtomicSymbolInfo( - handle, ctypes.byref(real_name), ctypes.byref(desc), - ctypes.byref(num_args), - ctypes.byref(arg_names), - ctypes.byref(arg_types), - ctypes.byref(arg_descs), - ctypes.byref(key_var_num_args), - ctypes.byref(ret_type))) - narg = int(num_args.value) - arg_names = [py_str(arg_names[i]) for i in range(narg)] - arg_types = [py_str(arg_types[i]) for i in range(narg)] - func_name = name - key_var_num_args = py_str(key_var_num_args.value) - ret_type = py_str(ret_type.value) if ret_type.value is not None else '' - doc_str = _build_doc(func_name, - py_str(desc.value), - arg_names, - arg_types, - [py_str(arg_descs[i]) for i in range(narg)], - key_var_num_args, - ret_type) - - dtype_name = None - arr_name = None - ndsignature = [] - signature = [] - ndarg_names = [] - kwarg_names = [] - for i in range(narg): - name, atype = arg_names[i], arg_types[i] - if name == 'dtype': - dtype_name = name - signature.append('%s=_Null'%name) - elif atype.startswith('NDArray') or atype.startswith('Symbol'): - assert not arr_name, \ - "Op can only have one argument with variable " \ - "size and it must be the last argument." - if atype.endswith('[]'): - ndsignature.append('*%s'%name) - arr_name = name - else: - ndsignature.append('%s=None'%name) - ndarg_names.append(name) - else: - signature.append('%s=_Null'%name) - kwarg_names.append(name) - #signature.append('is_train=False') - signature.append('out=None') - signature.append('name=None') - signature.append('**kwargs') - signature = ndsignature + signature - - code = [] - if arr_name: - code.append(""" -def %s(*%s, **kwargs):"""%(func_name, arr_name)) - code.append(""" - ndargs = [] - for i in {}: - assert isinstance(i, NDArrayBase), \\ - "Positional arguments must have NDArray type, " \\ - "but got %s"%str(i) - ndargs.append(i)""".format(arr_name)) - if dtype_name is not None: - code.append(""" - if '%s' in kwargs: - kwargs['%s'] = np.dtype(kwargs['%s']).name"""%( - dtype_name, dtype_name, dtype_name)) - code.append(""" - _ = kwargs.pop('name', None) - out = kwargs.pop('out', None) - keys = list(kwargs.keys()) - vals = list(kwargs.values())""") - else: - code.append(""" -def %s(%s): - ndargs = [] - keys = list(kwargs.keys()) - vals = list(kwargs.values())"""%(func_name, ', '.join(signature))) - # NDArray args - for name in ndarg_names: # pylint: disable=redefined-argument-from-local - code.append(""" - if {name} is not None: - assert isinstance({name}, NDArrayBase), \\ - "Argument {name} must have NDArray type, but got %s"%str({name}) - ndargs.append({name})""".format(name=name)) - # kwargs - for name in kwarg_names: # pylint: disable=redefined-argument-from-local - code.append(""" - if %s is not _Null: - keys.append('%s') - vals.append(%s)"""%(name, name, name)) - # dtype - if dtype_name is not None: - code.append(""" - if %s is not _Null: - keys.append('%s') - vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) - - code.append(""" - return 
_imperative_invoke(%d, ndargs, keys, vals, out)"""%( - handle.value)) - - local = {} - exec(''.join(code), None, local) # pylint: disable=exec-used - ndarray_function = local[func_name] - ndarray_function.__name__ = func_name - ndarray_function.__doc__ = doc_str - ndarray_function.__module__ = 'mxnet.ndarray' - return ndarray_function - - -# pylint: enable=too-many-locals, invalid-name -def _init_ndarray_module(root_namespace): - """List and add all the ndarray functions to current module.""" - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - check_call(_LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist))) - op_names = [] - for i in range(size.value): - op_names.append(py_str(plist[i])) - - module_obj = _sys.modules["%s.ndarray" % root_namespace] - module_internal = _sys.modules["%s._ndarray_internal" % root_namespace] - module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace] - for name in op_names: - hdl = OpHandle() - check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) - function = _make_ndarray_function(hdl, name) - if function.__name__.startswith('_contrib_'): - function.__name__ = function.__name__[9:] - function.__module__ = 'mxnet.contrib.ndarray' - setattr(module_contrib, function.__name__, function) - elif function.__name__.startswith('_'): - setattr(module_internal, function.__name__, function) - else: - setattr(module_obj, function.__name__, function) +def _zeros_ndarray(shape, ctx=None, dtype=None, **kwargs): + """Returns a new array filled with all zeros, with the given shape and type. + + Parameters + ---------- + shape : int or tuple of int + The shape of the empty array. + ctx : Context, optional + An optional device context (default is the current default context). + dtype : str or numpy.dtype, optional + An optional value type (default is `float32`). + out : NDArray, optional + The output NDArray (default is `None`). -# register backend operators in mx.nd -_init_ndarray_module("mxnet") + Returns + ------- + NDArray + A created array -# from .base import add_fileline_to_docstring -# add_fileline_to_docstring(__name__) + Examples + -------- + >>> mx.nd.zeros(1).asnumpy() + array([ 0.], dtype=float32) + >>> mx.nd.zeros((1,2), mx.gpu(0)) + + >>> mx.nd.zeros((1,2), mx.gpu(0), 'float16').asnumpy() + array([[ 0., 0.]], dtype=float16) + """ + # pylint: disable= unused-argument + if ctx is None: + ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype + # pylint: disable= no-member, protected-access + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) + # pylint: enable= no-member, protected-access diff --git a/python/mxnet/ndarray/ndarray_utils.py b/python/mxnet/ndarray/ndarray_utils.py new file mode 100644 index 000000000000..2516372d1b55 --- /dev/null +++ b/python/mxnet/ndarray/ndarray_utils.py @@ -0,0 +1,99 @@ +# coding: utf-8 +"""Utility functions for NDArray and SparseNDArray.""" +import ctypes + +from ..base import _LIB, check_call, py_str, c_str, string_types, mx_uint, NDArrayHandle, c_array +from .ndarray import NDArray, _zeros_ndarray +from .sparse_ndarray import _ndarray_cls, _zeros_sparse_ndarray + + +def zeros(shape, ctx=None, dtype=None, stype=None, aux_types=None, **kwargs): + if stype is None: + return _zeros_ndarray(shape, ctx, dtype, **kwargs) + else: + return _zeros_sparse_ndarray(stype, shape, ctx, dtype, aux_types, **kwargs) + + +def load(fname): + """Loads an array from file. + + See more details in ``save``. + + Parameters + ---------- + fname : str + The filename. 
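The `zeros` front end added above is a thin dispatcher: with no `stype` it falls through to the dense `_zeros_ndarray`, otherwise to `_zeros_sparse_ndarray`. Assuming the package `__init__` re-exports it as `mx.nd.zeros` (which the optimizer and test_utils changes later in this patch rely on), usage looks roughly like:

    import mxnet as mx

    a = mx.nd.zeros((2, 3))                        # dense path
    b = mx.nd.zeros((2, 3), stype='row_sparse')    # sparse path
    c = mx.nd.zeros((2, 3), stype='csr', dtype='float16')
    a.stype, b.stype, c.stype                      # ('default', 'row_sparse', 'csr')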
+ + Returns + ------- + list of NDArray or dict of str to NDArray + Loaded data. + """ + if not isinstance(fname, string_types): + raise TypeError('fname required to be a string') + out_size = mx_uint() + out_name_size = mx_uint() + handles = ctypes.POINTER(NDArrayHandle)() + names = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXNDArrayLoad(c_str(fname), + ctypes.byref(out_size), + ctypes.byref(handles), + ctypes.byref(out_name_size), + ctypes.byref(names))) + if out_name_size.value == 0: + return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)] + else: + assert out_name_size.value == out_size.value + return dict( + (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i]))) + for i in range(out_size.value)) + + +def save(fname, data): + """Saves a list of arrays or a dict of str->array to file. + + Examples of filenames: + + - ``/path/to/file`` + - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports) + - ``hdfs://path/to/file`` (if compiled with HDFS supports) + + Parameters + ---------- + fname : str + The filename. + data : list of ``NDArray` or dict of str to ``NDArray`` + The data to save. + + Examples + -------- + >>> x = mx.nd.zeros((2,3)) + >>> y = mx.nd.ones((1,4)) + >>> mx.nd.save('my_list', [x,y]) + >>> mx.nd.save('my_dict', {'x':x, 'y':y}) + >>> mx.nd.load('my_list') + [, ] + >>> mx.nd.load('my_dict') + {'y': , 'x': } + """ + handles = [] + if isinstance(data, dict): + keys = [] + for key, val in data.items(): + if not isinstance(key, string_types): + raise TypeError('save only accept dict str->NDArray or list of NDArray') + if not isinstance(val, NDArray): + raise TypeError('save only accept dict str->NDArray or list of NDArray') + keys.append(c_str(key)) + handles.append(val.handle) + keys = c_array(ctypes.c_char_p, keys) + else: + for val in data: + if not isinstance(val, NDArray): + raise TypeError('save only accept dict str->NDArray or list of NDArray') + handles.append(val.handle) + keys = None + check_call(_LIB.MXNDArraySave(c_str(fname), + mx_uint(len(handles)), + c_array(NDArrayHandle, handles), + keys)) diff --git a/python/mxnet/ndarray/op.py b/python/mxnet/ndarray/op.py new file mode 100644 index 000000000000..3c81bce4b261 --- /dev/null +++ b/python/mxnet/ndarray/op.py @@ -0,0 +1,189 @@ +"""Register backend ops in mxnet.ndarray namespace""" + +import sys as _sys +import os as _os +import ctypes +import numpy as np # pylint: disable=unused-import + +from ..ndarray_doc import _build_doc + +# Use different verison of SymbolBase +# When possible, use cython to speedup part of computation. 
+# pylint: disable=unused-import +try: + if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: + from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _STORAGE_TYPE_ID_TO_STR + from .._ctypes.ndarray import invoke, CachedOp, _imperative_invoke + elif _sys.version_info >= (3, 0): + from .._cy3.ndarray import NDArrayBase, _set_ndarray_class,\ + _imperative_invoke, _STORAGE_TYPE_ID_TO_STR + from .._cy3.ndarray import invoke, CachedOp, _imperative_invoke + else: + from .._cy2.ndarray import NDArrayBase, _set_ndarray_class,\ + _imperative_invoke, _STORAGE_TYPE_ID_TO_STR + from .._cy2.ndarray import invoke, CachedOp, _imperative_invoke +except ImportError: + if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: + raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") + from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class,\ + _imperative_invoke, _STORAGE_TYPE_ID_TO_STR + from .._ctypes.ndarray import invoke, CachedOp, _imperative_invoke + +from ..base import mx_uint, check_call, _LIB, py_str, OpHandle, c_str, _Null +# pylint: enable=unused-import + + +# pylint: disable=too-many-locals, invalid-name +def _make_ndarray_function(handle, name): + """Create a NDArray function from the FunctionHandle.""" + real_name = ctypes.c_char_p() + desc = ctypes.c_char_p() + num_args = mx_uint() + arg_names = ctypes.POINTER(ctypes.c_char_p)() + arg_types = ctypes.POINTER(ctypes.c_char_p)() + arg_descs = ctypes.POINTER(ctypes.c_char_p)() + key_var_num_args = ctypes.c_char_p() + ret_type = ctypes.c_char_p() + + check_call(_LIB.MXSymbolGetAtomicSymbolInfo( + handle, ctypes.byref(real_name), ctypes.byref(desc), + ctypes.byref(num_args), + ctypes.byref(arg_names), + ctypes.byref(arg_types), + ctypes.byref(arg_descs), + ctypes.byref(key_var_num_args), + ctypes.byref(ret_type))) + narg = int(num_args.value) + arg_names = [py_str(arg_names[i]) for i in range(narg)] + arg_types = [py_str(arg_types[i]) for i in range(narg)] + func_name = name + key_var_num_args = py_str(key_var_num_args.value) + ret_type = py_str(ret_type.value) if ret_type.value is not None else '' + doc_str = _build_doc(func_name, + py_str(desc.value), + arg_names, + arg_types, + [py_str(arg_descs[i]) for i in range(narg)], + key_var_num_args, + ret_type) + + dtype_name = None + arr_name = None + ndsignature = [] + signature = [] + ndarg_names = [] + kwarg_names = [] + for i in range(narg): + name, atype = arg_names[i], arg_types[i] + if name == 'dtype': + dtype_name = name + signature.append('%s=_Null'%name) + elif atype.startswith('NDArray') or atype.startswith('Symbol'): + assert not arr_name, \ + "Op can only have one argument with variable " \ + "size and it must be the last argument." 
+ if atype.endswith('[]'): + ndsignature.append('*%s'%name) + arr_name = name + else: + ndsignature.append('%s=None'%name) + ndarg_names.append(name) + else: + signature.append('%s=_Null'%name) + kwarg_names.append(name) + #signature.append('is_train=False') + signature.append('out=None') + signature.append('name=None') + signature.append('**kwargs') + signature = ndsignature + signature + + code = [] + if arr_name: + code.append(""" +def %s(*%s, **kwargs):"""%(func_name, arr_name)) + code.append(""" + ndargs = [] + for i in {}: + assert isinstance(i, NDArrayBase), \\ + "Positional arguments must have NDArray type, " \\ + "but got %s"%str(i) + ndargs.append(i)""".format(arr_name)) + if dtype_name is not None: + code.append(""" + if '%s' in kwargs: + kwargs['%s'] = np.dtype(kwargs['%s']).name"""%( + dtype_name, dtype_name, dtype_name)) + code.append(""" + _ = kwargs.pop('name', None) + out = kwargs.pop('out', None) + keys = list(kwargs.keys()) + vals = list(kwargs.values())""") + else: + code.append(""" +def %s(%s): + ndargs = [] + keys = list(kwargs.keys()) + vals = list(kwargs.values())"""%(func_name, ', '.join(signature))) + # NDArray args + for name in ndarg_names: # pylint: disable=redefined-argument-from-local + code.append(""" + if {name} is not None: + assert isinstance({name}, NDArrayBase), \\ + "Argument {name} must have NDArray type, but got %s"%str({name}) + ndargs.append({name})""".format(name=name)) + # kwargs + for name in kwarg_names: # pylint: disable=redefined-argument-from-local + code.append(""" + if %s is not _Null: + keys.append('%s') + vals.append(%s)"""%(name, name, name)) + # dtype + if dtype_name is not None: + code.append(""" + if %s is not _Null: + keys.append('%s') + vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) + + code.append(""" + return _imperative_invoke(%d, ndargs, keys, vals, out)"""%( + handle.value)) + + local = {} + exec(''.join(code), None, local) # pylint: disable=exec-used + ndarray_function = local[func_name] + ndarray_function.__name__ = func_name + ndarray_function.__doc__ = doc_str + ndarray_function.__module__ = 'mxnet.ndarray' + return ndarray_function + + +# pylint: enable=too-many-locals, invalid-name +def _init_ndarray_module(root_namespace): + """List and add all the ndarray functions to current module.""" + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + check_call(_LIB.MXListAllOpNames(ctypes.byref(size), + ctypes.byref(plist))) + op_names = [] + for i in range(size.value): + op_names.append(py_str(plist[i])) + + module_obj = _sys.modules["%s.ndarray" % root_namespace] + module_internal = _sys.modules["%s.ndarray._internal" % root_namespace] + module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace] + for name in op_names: + hdl = OpHandle() + check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) + function = _make_ndarray_function(hdl, name) + if function.__name__.startswith('_contrib_'): + function.__name__ = function.__name__[9:] + function.__module__ = 'mxnet.contrib.ndarray' + setattr(module_contrib, function.__name__, function) + elif function.__name__.startswith('_'): + setattr(module_internal, function.__name__, function) + else: + setattr(module_obj, function.__name__, function) + +# register backend operators in mx.nd +_init_ndarray_module("mxnet") diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/ndarray/sparse_ndarray.py similarity index 80% rename from python/mxnet/sparse_ndarray.py rename to python/mxnet/ndarray/sparse_ndarray.py index 
a438b4d6ec7d..720d44586a74 100644 --- a/python/mxnet/sparse_ndarray.py +++ b/python/mxnet/ndarray/sparse_ndarray.py @@ -15,31 +15,32 @@ # import operator import numpy as np -import mxnet as mx -from .base import _LIB, numeric_types -from .base import c_array, mx_real_t -from .base import mx_uint, NDArrayHandle, check_call -from .context import Context -from . import _ndarray_internal as _internal +from ..base import _LIB, numeric_types +from ..base import c_array, mx_real_t +from ..base import mx_uint, NDArrayHandle, check_call +from ..context import Context +from . import _internal from . import ndarray from .ndarray import _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP from .ndarray import _STORAGE_TYPE_STR_TO_ID -from .ndarray import NDArray, _storage_type +from .ndarray import NDArray, _storage_type, _zeros_ndarray +from . import cast_storage +from . import slice as nd_slice # Use different verison of SymbolBase # When possible, use cython to speedup part of computation. # pylint: disable=unused-import try: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: - from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class + from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class elif _sys.version_info >= (3, 0): - from ._cy3.ndarray import NDArrayBase, _set_ndarray_class + from .._cy3.ndarray import NDArrayBase, _set_ndarray_class else: - from ._cy2.ndarray import NDArrayBase, _set_ndarray_class + from .._cy2.ndarray import NDArrayBase, _set_ndarray_class except ImportError: if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") - from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class + from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class # pylint: enable=unused-import _STORAGE_AUX_TYPES = { @@ -48,7 +49,7 @@ } -def _new_alloc_handle(storage_type, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None): +def _new_alloc_handle(stype, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None): """Return a new handle with specified storage type, shape, dtype and context. 
Empty handle is only used to hold results @@ -65,7 +66,7 @@ def _new_alloc_handle(storage_type, shape, ctx, delay_alloc, dtype, aux_types, a aux_shapes = sum(aux_shapes, ()) num_aux = mx_uint(len(aux_types)) check_call(_LIB.MXNDArrayCreateSparseEx( - ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[storage_type])), + ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[stype])), c_array(mx_uint, shape), mx_uint(len(shape)), ctypes.c_int(ctx.device_typeid), @@ -114,13 +115,13 @@ def __setitem__(self, key, value): Examples -------- - >>> src = mx.sparse_nd.row_sparse(data, indices, (3,3)) + >>> src = mx.nd.row_sparse([[1, 0, 2], [4, 5, 6]], [0, 2], (3,3)) >>> src.asnumpy() array([[ 1., 0., 2.], [ 0., 0., 0.], [ 4., 5., 6.]], dtype=float32) >>> # assign SparseNDArray with same storage type - >>> x = mx.sparse_nd.zeros('row_sparse', (3,3)) + >>> x = mx.nd.zeros('row_sparse', (3,3)) >>> x[:] = src >>> x.asnumpy() array([[ 1., 0., 2.], @@ -168,6 +169,7 @@ def __getitem__(self, key): Examples -------- + >>> x = mx.nd.zeros((2, 3), stype='row_sparse') >>> x[:] = mx.nd.arange(0,6).reshape((2,3)) >>> x.asnumpy() array([[ 0., 1., 2.], @@ -186,7 +188,7 @@ def __getitem__(self, key): if key.start is not None or key.stop is not None: begin = key.start if key.start else 0 end = key.stop if key.stop else self.shape[0] - return ndarray.slice(self, begin=begin, end=end) + return nd_slice(self, begin=begin, end=end) else: return self if isinstance(key, tuple): @@ -217,7 +219,7 @@ def _aux_type(self, i): return _DTYPE_MX_TO_NP[aux_type.value] @property - def values(self): + def data(self): """The values array of the SparseNDArray. This is a read-only view of the values array. They reveal internal implementation details and should be used with care. @@ -228,7 +230,6 @@ def values(self): """ return self._data() - @property def _num_aux(self): ''' The number of aux data used to help store the sparse ndarray. @@ -241,7 +242,7 @@ def T(self): raise Exception('Transpose is not supported for SparseNDArray.') @property - def aux_types(self): + def _aux_types(self): """The data types of the aux data for the SparseNDArray. """ aux_types = [] @@ -264,13 +265,13 @@ def astype(self, dtype): The type of the returned array. Examples -------- - >>> x = mx.sparse_nd.zeros('row_sparse', (2,3), dtype='float32') + >>> x = mx.nd.zeros('row_sparse', (2,3), dtype='float32') >>> y = x.astype('int32') >>> y.dtype """ - res = mx.nd.zeros(shape=self.shape, ctx=self.context, - dtype=dtype, storage_type=self.stype) + res = _zeros_sparse_ndarray(shape=self.shape, ctx=self.context, + dtype=dtype, stype=self.stype) self.copyto(res) return res @@ -302,7 +303,7 @@ def copyto(self, other): return _internal._copyto(self, out=other) elif isinstance(other, Context): hret = _ndarray_cls(_new_alloc_handle(self.stype, self.shape, other, - True, self.dtype, self.aux_types)) + True, self.dtype, self._aux_types)) return _internal._copyto(self, out=hret) else: raise TypeError('copyto does not support type ' + str(type(other))) @@ -333,7 +334,18 @@ class CSRNDArray(SparseNDArray): row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored in values[indptr[i]:indptr[i+1]]. 
+ Example + ------- + >>> a = mx.nd.array([[0, 1, 0], [2, 0, 0], [0, 0, 0], [0, 0, 3]]) + >>> a = a._to_csr() + >>> a.indices.asnumpy() + array([1, 0, 2]) + >>> a.indptr.asnumpy() + array([0, 1, 2, 2, 3]) + >>> a.data.asnumpy() + array([ 1., 2., 3.], dtype=float32) """ + def __reduce__(self): return CSRNDArray, (None,), super(CSRNDArray, self).__getstate__() @@ -376,6 +388,17 @@ class RowSparseNDArray(SparseNDArray): RowSparseNDArray is used principally in the definition of gradients for operations that have sparse gradients (e.g. SparseEmbedding). + + Examples + -------- + >>> import mxnet as mx + >>> dense = mx.nd.array([[1,2],[0,0],[3,0],[0,0]]) + >>> rsp = dense._to_rsp() + >>> rsp.indices.asnumpy() + array([0, 2], dtype=int32) + >>> rsp.data.asnumpy() + array([[ 1., 2.], + [ 3., 0.]], dtype=float32) """ def __reduce__(self): return RowSparseNDArray, (None,), super(RowSparseNDArray, self).__getstate__() @@ -406,12 +429,12 @@ def _prepare_src_array(src, dtype, default_dtype): return src, dtype -def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None): +def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None): """Creates a 2D array with compressed sparse row format. Parameters ---------- - values: array_like + data: array_like An object exposing the array interface, with shape [nnz], where D0 is the number of non-zero entries. indptr: array_like @@ -419,9 +442,9 @@ def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, should always be zero. indices: array_like An object exposing the array interface, with shape [nnz]. - ctx : Context, optional + ctx: Context, optional Device context (default is the current default context). - dtype : str or numpy.dtype, optional + dtype: str or numpy.dtype, optional The data type of the output array. The default dtype is ``values.dtype`` if `values` is an `NDArray`, `float32` otherwise. indptr_type: str or numpy.dtype, optional @@ -435,13 +458,23 @@ def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, ------- CSRNDArray A `CSRNDArray` with the `csr` storage representation. 
+ + Example + ------- + >>> import mxnet as mx + >>> a = mx.nd.csr([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3)) + >>> a.asnumpy() + array([[ 0., 1., 0.], + [ 2., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 3.]], dtype=float32) """ storage_type = 'csr' # context if ctx is None: ctx = Context.default_ctx # prepare src array and types - values, dtype = _prepare_src_array(values, dtype, mx_real_t) + data, dtype = _prepare_src_array(data, dtype, mx_real_t) indptr, indptr_type = _prepare_src_array(indptr, indptr_type, _STORAGE_AUX_TYPES[storage_type][0]) indices, indices_type = _prepare_src_array(indices, indices_type, @@ -451,17 +484,17 @@ def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, assert('int64' in str(indices_type)), "expected int64 for indices" # verify shapes aux_shapes = [indptr.shape, indices.shape] - assert(values.ndim == 1) + assert(data.ndim == 1) assert(indptr.ndim == 1) assert(indices.ndim == 1) assert(len(shape) == 2) result = CSRNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype, [indptr_type, indices_type], aux_shapes)) - # assign indptr, indices and values - values_ref = result._data(True) + # assign indptr, indices and data + data_ref = result._data(True) indptr_ref = result._aux_data(0, True) indices_ref = result._aux_data(1, True) - values_ref[:] = values + data_ref[:] = data indptr_ref[:] = indptr indices_ref[:] = indices return result @@ -490,6 +523,17 @@ def row_sparse(values, indices, shape, ctx=None, dtype=None, indices_type=None): ------- RowSparseNDArray An `RowSparseNDArray` with the `row_sparse` storage representation. + + Example + ------- + >>> a = mx.nd.row_sparse([[1, 2], [3, 4]], [1, 4], (6, 2)) + >>> a.asnumpy() + array([[ 0., 0.], + [ 1., 2.], + [ 0., 0.], + [ 0., 0.], + [ 3., 4.], + [ 0., 0.]], dtype=float32) """ storage_type = 'row_sparse' # context @@ -522,20 +566,63 @@ def todense(source): NDArray The dense array with default storage """ - return ndarray.cast_storage(source, storage_type='default') + return cast_storage(source, stype='default') def _ndarray_cls(handle, writable=True, stype=None): if stype is None: stype = _storage_type(handle) if stype == 'default': - return NDArray(handle, writable=writable, stype=stype) + return NDArray(handle, writable=writable) elif stype == 'csr': - return CSRNDArray(handle, writable=writable, stype=stype) + return CSRNDArray(handle, writable=writable) elif stype == 'row_sparse': - return RowSparseNDArray(handle, writable=writable, stype=stype) + return RowSparseNDArray(handle, writable=writable) else: raise Exception("unknown storage type") _set_ndarray_class(_ndarray_cls) + + +def _zeros_sparse_ndarray(stype, shape, ctx=None, dtype=None, aux_types=None, **kwargs): + """Return a new array of given shape and type, filled with zeros. 
+ + Parameters + ---------- + shape : int or tuple of int + The shape of the empty array + stype: string + The storage type of the empty array, such as 'row_sparse', 'csr', etc + ctx : Context, optional + An optional device context (default is the current default context) + dtype : str or numpy.dtype, optional + An optional value type (default is `float32`) + aux_types: list of numpy.dtype, optional + An optional type for the aux data for SparseNDArray (default values depends + on the storage type) + + Returns + ------- + SparseNDArray + A created array + Examples + -------- + >>> mx.nd.zeros('csr', (1,2), mx.gpu(0)) + + >>> mx.nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy() + array([[ 0., 0.]], dtype=float16) + """ + if stype == 'default': + return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs) + if ctx is None: + ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype + if aux_types is None: + if stype == 'row_sparse' or stype == 'csr': + aux_types = _STORAGE_AUX_TYPES[stype] + else: + raise Exception("unknown storage type") + assert(len(aux_types) == len(_STORAGE_AUX_TYPES[stype])) + out = _ndarray_cls(_new_alloc_handle(stype, shape, ctx, True, dtype, aux_types)) + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs) diff --git a/python/mxnet/ndarray_utils.py b/python/mxnet/ndarray_utils.py deleted file mode 100644 index 5f8fa6c7bfb7..000000000000 --- a/python/mxnet/ndarray_utils.py +++ /dev/null @@ -1,198 +0,0 @@ -# coding: utf-8 -"""Utility functions for NDArray and SparseNDArray.""" -import ctypes -import sys as _sys - -from mxnet import Context -from mxnet.base import mx_real_t, _LIB, check_call, py_str, c_str, string_types, mx_uint,\ - NDArrayHandle, c_array -from mxnet.ndarray import NDArray -from mxnet.sparse_ndarray import _STORAGE_AUX_TYPES, _new_alloc_handle, _ndarray_cls -from . import _ndarray_internal as _internal - - -def _zeros_ndarray(shape, ctx=None, dtype=None, **kwargs): - """Returns a new array filled with all zeros, with the given shape and type. - - Parameters - ---------- - shape : int or tuple of int - The shape of the empty array. - ctx : Context, optional - An optional device context (default is the current default context). - dtype : str or numpy.dtype, optional - An optional value type (default is `float32`). - out : NDArray, optional - The output NDArray (default is `None`). - - Returns - ------- - NDArray - A created array - - Examples - -------- - >>> mx.nd.zeros(1).asnumpy() - array([ 0.], dtype=float32) - >>> mx.nd.zeros((1,2), mx.gpu(0)) - - >>> mx.nd.zeros((1,2), mx.gpu(0), 'float16').asnumpy() - array([[ 0., 0.]], dtype=float16) - """ - # pylint: disable= unused-argument - if ctx is None: - ctx = Context.default_ctx - dtype = mx_real_t if dtype is None else dtype - # pylint: disable= no-member, protected-access - return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) - # pylint: enable= no-member, protected-access - - -def _zeros_sparse_ndarray(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs): - """Return a new array of given shape and type, filled with zeros. 
- - Parameters - ---------- - shape : int or tuple of int - The shape of the empty array - storage_type: string - The storage type of the empty array, such as 'row_sparse', 'csr', etc - ctx : Context, optional - An optional device context (default is the current default context) - dtype : str or numpy.dtype, optional - An optional value type (default is `float32`) - aux_types: list of numpy.dtype, optional - An optional type for the aux data for SparseNDArray (default values depends - on the storage type) - - Returns - ------- - SparseNDArray - A created array - Examples - -------- - >>> mx.sparse_nd.zeros('csr', (1,2), mx.gpu(0)) - - >>> mx.sparse_nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy() - array([[ 0., 0.]], dtype=float16) - """ - if storage_type == 'default': - return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs) - if ctx is None: - ctx = Context.default_ctx - dtype = mx_real_t if dtype is None else dtype - if aux_types is None: - if storage_type == 'row_sparse' or storage_type == 'csr': - aux_types = _STORAGE_AUX_TYPES[storage_type] - else: - raise Exception("unknown storage type") - assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type])) - out = _ndarray_cls(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types)) - return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs) - - -def zeros(shape, ctx=None, dtype=None, storage_type=None, aux_types=None, **kwargs): - if storage_type is None: - return _zeros_ndarray(shape, ctx, dtype, **kwargs) - else: - return _zeros_sparse_ndarray(storage_type, shape, ctx, dtype, aux_types, **kwargs) - - -def load(fname): - """Loads an array from file. - - See more details in ``save``. - - Parameters - ---------- - fname : str - The filename. - - Returns - ------- - list of NDArray or dict of str to NDArray - Loaded data. - """ - if not isinstance(fname, string_types): - raise TypeError('fname required to be a string') - out_size = mx_uint() - out_name_size = mx_uint() - handles = ctypes.POINTER(NDArrayHandle)() - names = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.MXNDArrayLoad(c_str(fname), - ctypes.byref(out_size), - ctypes.byref(handles), - ctypes.byref(out_name_size), - ctypes.byref(names))) - if out_name_size.value == 0: - return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)] - else: - assert out_name_size.value == out_size.value - return dict( - (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i]))) - for i in range(out_size.value)) - - -def save(fname, data): - """Saves a list of arrays or a dict of str->array to file. - - Examples of filenames: - - - ``/path/to/file`` - - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports) - - ``hdfs://path/to/file`` (if compiled with HDFS supports) - - Parameters - ---------- - fname : str - The filename. - data : list of ``NDArray` or dict of str to ``NDArray`` - The data to save. 
- - Examples - -------- - >>> x = mx.nd.zeros((2,3)) - >>> y = mx.nd.ones((1,4)) - >>> mx.nd.save('my_list', [x,y]) - >>> mx.nd.save('my_dict', {'x':x, 'y':y}) - >>> mx.nd.load('my_list') - [, ] - >>> mx.nd.load('my_dict') - {'y': , 'x': } - """ - handles = [] - if isinstance(data, dict): - keys = [] - for key, val in data.items(): - if not isinstance(key, string_types): - raise TypeError('save only accept dict str->NDArray or list of NDArray') - if not isinstance(val, NDArray): - raise TypeError('save only accept dict str->NDArray or list of NDArray') - keys.append(c_str(key)) - handles.append(val.handle) - keys = c_array(ctypes.c_char_p, keys) - else: - for val in data: - if not isinstance(val, NDArray): - raise TypeError('save only accept dict str->NDArray or list of NDArray') - handles.append(val.handle) - keys = None - check_call(_LIB.MXNDArraySave(c_str(fname), - mx_uint(len(handles)), - c_array(NDArrayHandle, handles), - keys)) - - -def _init_ndarray_module_frontend(function, root_namespace, module_name): - """Register front end functions defined in this file to mxnet.ndarray module. - The functions registered were originally defined in mxnet.ndarray. They were - moved here because they need to know SparseNDArray class, while it's not allowed - in ndarray.py since that would result in circular import.""" - module_obj = _sys.modules["%s.%s" % (root_namespace, module_name)] - setattr(module_obj, function.__name__, function) - - -# register the following front end functions in mx.nd -_init_ndarray_module_frontend(zeros, "mxnet", "ndarray") -_init_ndarray_module_frontend(load, "mxnet", "ndarray") -_init_ndarray_module_frontend(save, "mxnet", "ndarray") diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 7e46c30c7c79..ebbf70101353 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -2,10 +2,8 @@ import math import pickle import logging -import mxnet as mx -from .ndarray import NDArray, clip, sqrt, sign +from .ndarray import NDArray, clip, sqrt, sign, zeros from .ndarray import sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update -from .ndarray_utils import zeros from .random import normal @@ -334,8 +332,8 @@ def create_state(self, index, weight): if self.momentum == 0.0: return None else: - return mx.nd.zeros(shape=weight.shape, ctx=weight.context, - dtype=weight.dtype, storage_type=weight.stype) + return zeros(shape=weight.shape, ctx=weight.context, + dtype=weight.dtype, stype=weight.stype) def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) @@ -513,8 +511,8 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, self.epsilon = epsilon def create_state(self, index, weight): - return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype), # mean - mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance + return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean + zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) @@ -619,11 +617,11 @@ def __init__(self, learning_rate=0.001, gamma1=0.9, gamma2=0.9, def create_state(self, index, weight): if self.centered: return ( - mx.nd.zeros(weight.shape, weight.context), # n - mx.nd.zeros(weight.shape, weight.context), # g - mx.nd.zeros(weight.shape, weight.context)) # delta + zeros(weight.shape, weight.context), # n + zeros(weight.shape, weight.context), # g + zeros(weight.shape, weight.context)) 
# delta else: - return (mx.nd.zeros(weight.shape, weight.context), ) # n + return (zeros(weight.shape, weight.context),) # n def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) diff --git a/python/mxnet/random.py b/python/mxnet/random.py index 91c2f5035ffa..5707632c83c1 100644 --- a/python/mxnet/random.py +++ b/python/mxnet/random.py @@ -5,13 +5,13 @@ import ctypes from .base import _LIB, check_call -from ._ndarray_internal import _sample_uniform as uniform -from ._ndarray_internal import _sample_normal as normal -from ._ndarray_internal import _sample_gamma as gamma -from ._ndarray_internal import _sample_exponential as exponential -from ._ndarray_internal import _sample_poisson as poisson -from ._ndarray_internal import _sample_negbinomial as negative_binomial -from ._ndarray_internal import _sample_gennegbinomial as generalized_negative_binomial +from .ndarray._internal import _sample_uniform as uniform +from .ndarray._internal import _sample_normal as normal +from .ndarray._internal import _sample_gamma as gamma +from .ndarray._internal import _sample_exponential as exponential +from .ndarray._internal import _sample_poisson as poisson +from .ndarray._internal import _sample_negbinomial as negative_binomial +from .ndarray._internal import _sample_gennegbinomial as generalized_negative_binomial def seed(seed_state): """Seeds the random number generators in MXNet. diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 796ca77eaa13..6519802b1535 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -17,10 +17,10 @@ from .base import NDArrayHandle, ExecutorHandle, SymbolHandle, OpHandle from .base import check_call, MXNetError, _Null # pylint: disable=unused-import from .context import Context, cpu -from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .ndarray.ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP from .name import NameManager # pylint: disable=unused-import -from .ndarray import _STORAGE_TYPE_STR_TO_ID -from .sparse_ndarray import _ndarray_cls +from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID +from .ndarray.sparse_ndarray import _ndarray_cls from .executor import Executor from . import _symbol_internal as _internal from .attribute import AttrScope @@ -1162,7 +1162,7 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing): raise TypeError('Only accept list of NDArrays or dict of str to NDArray') return c_array(NDArrayHandle, arg_handles), arg_arrays - def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=None, + def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None, group2ctx=None, shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs): """Bind current symbol to get an executor, allocate all the arguments needed. 
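One practical effect of the optimizer hunk above: because SGD's `create_state` now calls the package-level `zeros` with `stype=weight.stype`, the momentum buffer follows the storage type of the weight it tracks. A minimal sketch of the intended behaviour (the index `0` is arbitrary; assumes the `zeros` front end described earlier):

    import mxnet as mx

    opt = mx.optimizer.SGD(momentum=0.9)
    w_dense = mx.nd.zeros((4, 2))
    w_rsp = mx.nd.zeros((4, 2), stype='row_sparse')
    opt.create_state(0, w_dense).stype   # 'default'
    opt.create_state(0, w_rsp).stype     # 'row_sparse'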
@@ -1206,7 +1206,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=N type_dict : Dict of str->numpy.dtype Input type dictionary, name->dtype - storage_type_dict : Dict of str->str + stype_dict : Dict of str->str Input storage type dictionary, name->storage_type group2ctx : Dict of string to mx.Context @@ -1255,10 +1255,10 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=N # provided storage type argument names provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)() provided_arg_stype_data = ctypes.POINTER(mx_uint)() # provided storage types - if storage_type_dict is not None: + if stype_dict is not None: provided_arg_stype_names = [] provided_arg_stype_data = [] - for k, v in storage_type_dict.items(): + for k, v in stype_dict.items(): if v in _STORAGE_TYPE_STR_TO_ID: provided_arg_stype_names.append(c_str(k)) provided_arg_stype_data.append(ctypes.c_int(_STORAGE_TYPE_STR_TO_ID[v])) @@ -1339,7 +1339,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=N shared_buffer_names = [] shared_buffer_handles = [] for k, v in shared_buffer.items(): - assert(v.storage_type == 'default'), \ + assert(v.stype == 'default'), \ "shared_buffer is expected to only contain NDArrays with default storage" shared_buffer_names.append(c_str(k)) shared_buffer_handles.append(v.handle) @@ -1669,7 +1669,7 @@ def reshape(self, shape): return reshape(self, shape=shape) def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, - init=None, storage_type=None, **kwargs): + init=None, stype=None, **kwargs): """Creates a symbolic variable with specified name. Example usage: @@ -1696,6 +1696,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, The dtype for input variable. If not specified, this value will be inferred. init : initializer (mxnet.init.*) Initializer for this variable to (optionally) override the default initializer. + stype : str + The storage type of the variable. kwargs : Additional attribute variables Additional attributes must start and end with double underscores. @@ -1723,8 +1725,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, if not isinstance(init, string_types): init = init.dumps() attr['__init__'] = init - if storage_type is not None: - attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[storage_type]) + if stype is not None: + attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[stype]) for k, v in kwargs.items(): if k.startswith('__') and k.endswith('__'): attr[k] = str(v) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index d860b531e520..ded0f65ebb0e 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -21,7 +21,7 @@ pass import mxnet as mx from .context import Context -from .ndarray import array, _STORAGE_TYPE_STR_TO_ID +from .ndarray.ndarray import array, _STORAGE_TYPE_STR_TO_ID from .symbol import Symbol _rng = np.random.RandomState(1234) @@ -76,10 +76,10 @@ def random_sample(population, k): return population_copy[0:k] -def rand_sparse_ndarray(shape, storage_type, density=None): +def rand_sparse_ndarray(shape, stype, density=None): """Generate a random sparse ndarray. 
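On the symbol side the same rename lands in two places: `var` now takes `stype` (recorded as a numeric id in the `__storage_type__` attribute) and `simple_bind` takes `stype_dict`. A sketch of both, with illustrative names and a deliberately trivial graph:

    import mxnet as mx

    # Declare a variable whose storage type should be row_sparse.
    w = mx.sym.var('w', stype='row_sparse')

    # Or pass storage types explicitly at bind time via stype_dict.
    data = mx.sym.var('data')
    net = mx.sym.elemwise_add(data, data)
    exe = net.simple_bind(mx.cpu(), data=(4, 3),
                          stype_dict={'data': 'default'},
                          grad_req='write')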
Returns the ndarray, value(np) and indices(np) """ density = rnd.rand() if density is None else density - if storage_type == 'row_sparse': + if stype == 'row_sparse': # TODO(haibin) support high dim sparse ndarray assert(len(shape) < 3) prod = np.prod(shape) @@ -88,26 +88,26 @@ def rand_sparse_ndarray(shape, storage_type, density=None): idx_sample = rnd.rand(shape[0]) indices = np.argwhere(idx_sample < density).flatten() if indices.shape[0] == 0: - result = mx.nd.zeros(shape, storage_type='row_sparse') + result = mx.nd.zeros(shape, stype='row_sparse') return result, (np.array([], dtype='int64'), np.array([], dtype='int64')) # generate random values val = rnd.rand(indices.shape[0], num_cols) - arr = mx.sparse_nd.row_sparse(val, indices, shape, indices_type=np.int64) + arr = mx.nd.row_sparse(val, indices, shape, indices_type=np.int64) return arr, (val, indices) - elif storage_type == 'csr': + elif stype == 'csr': assert(len(shape) == 2) csr = sp.rand(shape[0], shape[1], density=density, format='csr') - result = mx.sparse_nd.csr(csr.data, csr.indptr, csr.indices, shape) + result = mx.nd.csr(csr.data, csr.indptr, csr.indices, shape) return result, (csr.indptr, csr.indices, csr.data) else: assert(False), "unknown storage type" -def rand_ndarray(shape, storage_type, density=None): - if storage_type == 'default': +def rand_ndarray(shape, stype, density=None): + if stype == 'default': arr = mx.nd.array(random_arrays(shape)) else: - arr, _ = rand_sparse_ndarray(shape, storage_type, density=density) + arr, _ = rand_sparse_ndarray(shape, stype, density=density) return arr @@ -554,7 +554,7 @@ def random_projection(shape): assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict" for k, v in grad_stype_dict.items(): if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default': - args_grad[k] = mx.nd.cast_storage(args_grad[k], storage_type=v) + args_grad[k] = mx.nd.cast_storage(args_grad[k], stype=v) executor = out.bind(ctx, grad_req=grad_req, args=location, args_grad=args_grad, aux_states=aux_states) @@ -654,7 +654,7 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=None, - aux_states=None, grad_req='write', ctx=None): + aux_states=None, grad_req='write', ctx=None, grad_stypes=None): """Compares a symbol's backward results with the expected ones. Prints error messages if the backward results are not the same as the expected results. @@ -690,6 +690,8 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= Gradient requirements. 'write', 'add' or 'null'. ctx : Context, optional Running context. + grad_stypes: dict of str->str + dictionary of mapping argument name to stype for the gradient Example ------- @@ -715,16 +717,11 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()} - # args_grad_data should be casted to storage type if hinted - # TODO(haibin) this is a temporary solution for testing. 
remove later - attrs = sym.attr_dict() args_grad_data = {} for k, v in args_grad_npy.items(): - attr = attrs.get(k, {}) - grad_stype = attr.get('grad_stype_hint', None) nd = mx.nd.array(v, ctx=ctx) - if grad_stype is not None: - out = mx.nd.cast_storage(nd, storage_type=grad_stype) + if grad_stypes is not None and k in grad_stypes: + out = mx.nd.cast_storage(nd, stype=grad_stypes[k]) args_grad_data[k] = out else: args_grad_data[k] = nd diff --git a/python/setup.py b/python/setup.py index 8a8693038b3c..1f3abf536f86 100644 --- a/python/setup.py +++ b/python/setup.py @@ -74,7 +74,7 @@ def config_cython(): version=__version__, description=open(os.path.join(CURRENT_DIR, 'README.md')).read(), packages=[ - 'mxnet', 'mxnet.module', 'mxnet._ctypes', 'mxnet.rnn', + 'mxnet', 'mxnet.module', 'mxnet._ctypes', 'mxnet.rnn', 'mxnet.ndarray', 'mxnet._cy2', 'mxnet._cy3', 'mxnet.notebook', 'mxnet.contrib' ], data_files=[('mxnet', [LIB_PATH[0]])], diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 615b0231d750..f6ec268d1512 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -489,14 +489,28 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, NDArrayHandle **outputs, int num_params, const char **param_keys, - const char **param_vals, - const int** out_stypes) { // outputs storage types + const char **param_vals) { const nnvm::Op* op = static_cast(creator); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); nnvm::NodeAttrs attrs; SetOpAttrs(op, &attrs, num_inputs, num_params, param_keys, param_vals); ImperativeInvokeImpl(attrs, num_inputs, inputs, num_outputs, outputs); + API_END(); +} + +int MXImperativeInvokeEx(AtomicSymbolCreator creator, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs, + int num_params, + const char **param_keys, + const char **param_vals, + const int **out_stypes) { // outputs storage types + API_BEGIN(); + MXImperativeInvoke(creator, num_inputs, inputs, num_outputs, outputs, + num_params, param_keys, param_vals); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); NDArray** output_nds = reinterpret_cast(*outputs); ret->out_types.resize(*num_outputs); for (int i = 0; i < *num_outputs; ++i) { diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 218781f5bbf7..3d522c83efac 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -12,7 +12,7 @@ #include #include #include "mxnet/ndarray.h" -#include "../common/utils.h" +#include "../ndarray/ndarray_function.h" namespace mxnet { namespace kvstore { /** @@ -142,7 +142,8 @@ class CommCPU : public Comm { Engine::Get()->PushSync([reduce, result, this](RunContext rctx) { NDArray out = result; is_serial_push_? 
- ReduceSumCPUExSerial(reduce, &out) : ReduceSumCPUExParallel(reduce, &out); + ReduceSumCPUExSerial(reduce, &out) + : mxnet::ndarray::ElementwiseSum(rctx.get_stream(), reduce, &out); }, Context::CPU(), const_vars, {result.var()}, FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); } @@ -251,115 +252,6 @@ class CommCPU : public Comm { }); } - template - void ReduceSumCPUExImpl(const std::vector& nds, - const std::vector& uniq_row_idx, - NDArray* out) { -#pragma omp parallel num_threads(nthread_reduction_) - { - const size_t nnr = uniq_row_idx.size(); - const int num_threads = omp_get_num_threads(); - size_t row_block_len = (nnr + num_threads - 1) / num_threads; - const size_t row_block_start = omp_get_thread_num() * row_block_len; - if (row_block_start < nnr) { - const size_t row_block_end = std::min(row_block_start+row_block_len, nnr); - - auto out_values = out->data().FlatTo2D(); - auto out_indices = out->aux_data(rowsparse::kIdx).FlatTo1D(); - for (size_t i = row_block_start; i < row_block_end; ++i) { - out_indices[i] = uniq_row_idx[i]; - } - for (const auto& nd : nds) { - if (nd.storage_initialized()) { - const auto nd_indices = nd.aux_data(rowsparse::kIdx).FlatTo1D(); - const auto nd_values = nd.data().FlatTo2D(); - const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size(); - const IType* nd_indices_start = &nd_indices[0]; - const IType* nd_indices_end = nd_indices_start + nd_num_rows; - const IType* row_idx_ptr = std::lower_bound(nd_indices_start, nd_indices_end, - out_indices[row_block_start]); - // skip this nd if all of its row indices are smaller than out_indices[row_block_start] - // or current row block is not covered by [*row_idx_ptr, nd_indices_end). - if (nd_indices_end == row_idx_ptr || *row_idx_ptr > out_indices[row_block_end-1]) { - continue; - } - for (size_t irow = row_block_start; - irow < row_block_end && row_idx_ptr != nd_indices_end;) { - if (out_indices[irow] == *row_idx_ptr) { - auto out_value_cur_row = out_values[irow]; - const auto offset = row_idx_ptr - nd_indices_start; - auto nd_value_cur_row = nd_values[offset]; - for (size_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) { - out_value_cur_row[j] += nd_value_cur_row[j]; - } - ++irow; - ++row_idx_ptr; - } else if (out_indices[irow] < *row_idx_ptr) { - ++irow; - } else { - ++row_idx_ptr; - } - } - } - } - } - } - } - - /*! - * \brief Given a vector of ndarrays, generate a index vector containing - * all the unique row indices of the ndarrays. 
- */ - template - void GetUniqueRspRowIdx(const std::vector& nds, - std::vector* uniq_row_idx) { - using namespace rowsparse; - size_t total_num_rows = 0; - for (const auto& nd : nds) { - CHECK_EQ(nd.storage_type(), kRowSparseStorage); - if (nd.storage_initialized()) { - total_num_rows += nd.aux_shape(kIdx).Size(); - } - } - - uniq_row_idx->resize(total_num_rows); - int nthreads = omp_get_max_threads(); - int offset = 0; - for (const auto& nd : nds) { - if (nd.storage_initialized()) { - const IType* nd_row_idx = nd.aux_data(kIdx).dptr(); - const int num_rows = nd.aux_shape(kIdx).Size(); -#pragma omp parallel for num_threads(nthreads) - for (int i = 0; i < num_rows; ++i) { - (*uniq_row_idx)[offset+i] = nd_row_idx[i]; - } - offset += num_rows; - } - } - - common::ParallelSort(uniq_row_idx->begin(), uniq_row_idx->end(), nthreads); - auto it = std::unique(uniq_row_idx->begin(), uniq_row_idx->end()); - uniq_row_idx->resize(it - uniq_row_idx->begin()); - } - - void ReduceSumCPUExParallel(const std::vector& nds, NDArray* out) { - if (nds.empty()) return; - using namespace rowsparse; - CHECK_EQ(out->storage_type(), kRowSparseStorage) - << "Expected row sparse storage type (" - << out->storage_type() << " given)"; - - MSHADOW_TYPE_SWITCH(out->dtype(), DType, { - MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, { - std::vector uniq_row_idx; - GetUniqueRspRowIdx(nds, &uniq_row_idx); - out->CheckAndAlloc({mshadow::Shape1(uniq_row_idx.size())}); - out->data().FlatTo2D() = static_cast(0); - ReduceSumCPUExImpl(nds, uniq_row_idx, out); - }); - }); - } - template inline static void ReduceSumCPU( const std::vector &dptr, size_t offset, index_t size) { diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc index a5ba2660fd34..b03166f4d834 100644 --- a/src/ndarray/ndarray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -7,6 +7,7 @@ // this will be invoked by gcc and compile CPU version #include "./ndarray_function.h" #include "./ndarray_function-inl.h" +#include "../common/utils.h" namespace mxnet { namespace ndarray { @@ -26,5 +27,134 @@ void Copy(const TBlob &from, TBlob *to, } }) } + +template +void ElementwiseSumRspImpl(const std::vector& nds, + const std::vector& uniq_row_idx, + NDArray* out, + const int nthreads = 4) { +#pragma omp parallel num_threads(nthreads) + { + const size_t nnr = uniq_row_idx.size(); + const int num_threads = omp_get_num_threads(); + size_t row_block_len = (nnr + num_threads - 1) / num_threads; + const size_t row_block_start = omp_get_thread_num() * row_block_len; + if (row_block_start < nnr) { + const size_t row_block_end = std::min(row_block_start+row_block_len, nnr); + + auto out_values = out->data().FlatTo2D(); + auto out_indices = out->aux_data(rowsparse::kIdx).FlatTo1D(); + for (size_t i = row_block_start; i < row_block_end; ++i) { + out_indices[i] = uniq_row_idx[i]; + } + for (const auto& nd : nds) { + if (nd.storage_initialized()) { + const auto nd_indices = nd.aux_data(rowsparse::kIdx).FlatTo1D(); + const auto nd_values = nd.data().FlatTo2D(); + const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size(); + const IType* nd_indices_start = &nd_indices[0]; + const IType* nd_indices_end = nd_indices_start + nd_num_rows; + const IType* row_idx_ptr = std::lower_bound(nd_indices_start, nd_indices_end, + out_indices[row_block_start]); + // skip this nd if all of its row indices are smaller than out_indices[row_block_start] + // or current row block is not covered by [*row_idx_ptr, nd_indices_end). 
+ if (nd_indices_end == row_idx_ptr || *row_idx_ptr > out_indices[row_block_end-1]) { + continue; + } + for (size_t irow = row_block_start; + irow < row_block_end && row_idx_ptr != nd_indices_end;) { + if (out_indices[irow] == *row_idx_ptr) { + auto out_value_cur_row = out_values[irow]; + const auto offset = row_idx_ptr - nd_indices_start; + auto nd_value_cur_row = nd_values[offset]; + for (size_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) { + out_value_cur_row[j] += nd_value_cur_row[j]; + } + ++irow; + ++row_idx_ptr; + } else if (out_indices[irow] < *row_idx_ptr) { + ++irow; + } else { + ++row_idx_ptr; + } + } + } + } + } + } +} + +/*! + * \brief Given a vector of ndarrays, generate a index vector containing + * all the unique row indices of the ndarrays. + */ +template +void GetUniqueRspRowIdx(const std::vector& nds, + std::vector* uniq_row_idx) { + using namespace rowsparse; + size_t total_num_rows = 0; + for (const auto& nd : nds) { + CHECK_EQ(nd.storage_type(), kRowSparseStorage); + if (nd.storage_initialized()) { + total_num_rows += nd.aux_shape(kIdx).Size(); + } + } + + uniq_row_idx->resize(total_num_rows); + int nthreads = omp_get_max_threads(); + int offset = 0; + for (const auto& nd : nds) { + if (nd.storage_initialized()) { + const IType* nd_row_idx = nd.aux_data(kIdx).dptr(); + const int num_rows = nd.aux_shape(kIdx).Size(); +#pragma omp parallel for num_threads(nthreads) + for (int i = 0; i < num_rows; ++i) { + (*uniq_row_idx)[offset+i] = nd_row_idx[i]; + } + offset += num_rows; + } + } + + common::ParallelSort(uniq_row_idx->begin(), uniq_row_idx->end(), nthreads); + auto it = std::unique(uniq_row_idx->begin(), uniq_row_idx->end()); + uniq_row_idx->resize(it - uniq_row_idx->begin()); +} + +void ElementwiseSumRsp(const std::vector& nds, NDArray* out) { + if (nds.empty()) return; + using namespace rowsparse; + CHECK_EQ(out->storage_type(), kRowSparseStorage) + << "Expected row sparse storage type (" + << out->storage_type() << " given)"; + + MSHADOW_TYPE_SWITCH(out->dtype(), DType, { + MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, { + std::vector uniq_row_idx; + GetUniqueRspRowIdx(nds, &uniq_row_idx); + out->CheckAndAlloc({mshadow::Shape1(uniq_row_idx.size())}); + out->data().FlatTo2D() = static_cast(0); + ElementwiseSumRspImpl(nds, uniq_row_idx, out, omp_get_max_threads()); + }); + }); +} + +/*! + * \brief Parallel cpu impl of elemwise sum for sparse tensors. + * Currently only support row sparse sum. + */ +template<> +void ElementwiseSum(mshadow::Stream* s, + const std::vector& nds, + NDArray* out) { + if (nds.empty()) return; + + if (nds[0].storage_type() == kRowSparseStorage) { + ElementwiseSumRsp(nds, out); + } else { + LOG(FATAL) << "ElementwiseSum has not been implemented for storage_type = << " + << nds[0].storage_type(); + } +} + } // namespace ndarray } // namespace mxnet diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h index 00dd3d0e959a..cb69721b7ccf 100644 --- a/src/ndarray/ndarray_function.h +++ b/src/ndarray/ndarray_function.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "../operator/mshadow_op.h" @@ -146,6 +147,14 @@ void ElementwiseSum(const std::vector source, TBlob *out, RunContext ctx); +/*! 
+ * \brief Interface for parallel impl of elemwise sum for sparse matrices + */ +template +void ElementwiseSum(mshadow::Stream* s, + const std::vector& nds, + NDArray* out); + // broadcasting template void EvalBroadcast(TBlob const& src, TBlob* ret, int size, RunContext ctx); diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h index 9273b996d48e..da9ed30b998a 100644 --- a/src/operator/tensor/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -291,9 +291,9 @@ void CastStorageComputeImpl(mshadow::Stream* s, } struct CastStorageParam : public dmlc::Parameter { - int storage_type; + int stype; DMLC_DECLARE_PARAMETER(CastStorageParam) { - DMLC_DECLARE_FIELD(storage_type) + DMLC_DECLARE_FIELD(stype) .add_enum("default", kDefaultStorage) .add_enum("row_sparse", kRowSparseStorage) .add_enum("csr", kCSRStorage) @@ -310,9 +310,9 @@ inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs, CHECK_NE(in_attrs->at(0), kUndefinedStorage) << "src ndarray's storage type must be specified"; const CastStorageParam& param = nnvm::get(attrs.parsed); - CHECK_NE(param.storage_type, kUndefinedStorage) + CHECK_NE(param.stype, kUndefinedStorage) << "dst ndarray's storage type must be specified"; - TYPE_ASSIGN_CHECK(*out_attrs, 0, param.storage_type); + TYPE_ASSIGN_CHECK(*out_attrs, 0, param.stype); return true; } diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 46aa6fcd73a4..6e4b380f893b 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -806,9 +806,8 @@ inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); - if (kRowSparseStorage == in_attrs->at(sr::kArr)) { - out_attrs->at(sr::kOut) = kRowSparseStorage; - } + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kArr, kRowSparseStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kOut, kRowSparseStorage); return true; } @@ -818,8 +817,10 @@ inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 2U); - out_attrs->at(sr::kArr) = kRowSparseStorage; - out_attrs->at(sr::kIdx) = kDefaultStorage; + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kOut, kDefaultStorage); + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kIdx, kDefaultStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kArr, kRowSparseStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kIdx, kDefaultStorage); return true; } diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index bd12f95b2496..1489b8687c26 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -11,9 +11,9 @@ def init_kv(stype='default'): """init kv """ kv = mx.kv.create() # single - kv.init(3, mx.nd.zeros(shape=shape, storage_type=stype)) + kv.init(3, mx.nd.zeros(shape=shape, stype=stype)) # list - kv.init(keys, [mx.nd.zeros(shape=shape, storage_type=stype)] * len(keys)) + kv.init(keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys)) return kv def init_kv_with_str(): diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 96fd77334d8d..1c65f676955f 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -380,10 +380,10 @@ def test_module_fm(): rnd.seed(11) def fm_model(k, feature_dim): norm = mx.initializer.Normal(sigma=0.01) - x = 
mx.symbol.Variable("data", storage_type='csr') - v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, storage_type='row_sparse') + x = mx.symbol.Variable("data", stype='csr') + v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, stype='row_sparse') - w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm, storage_type='row_sparse') + w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm, stype='row_sparse') w1 = mx.symbol.dot(x, w1_weight) v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1) @@ -412,7 +412,7 @@ def fm_model(k, feature_dim): import scipy.sparse as sp # generate some random scipy csr data csr_sp = sp.rand(num_samples, feature_dim, density=0.5, format='csr') - csr_nd = mx.sparse_nd.csr(csr_sp.data, csr_sp.indptr, csr_sp.indices, + csr_nd = mx.nd.csr(csr_sp.data, csr_sp.indptr, csr_sp.indices, (num_samples, feature_dim)) label = mx.nd.ones((num_samples,1)) # the alternative is to use LibSVMIter @@ -443,9 +443,9 @@ def fm_model(k, feature_dim): def test_module_initializer(): def regression_model(m): - x = mx.symbol.var("data", storage_type='csr') + x = mx.symbol.var("data", stype='csr') v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1), - storage_type='row_sparse') + stype='row_sparse') model = mx.symbol.dot(lhs=x, rhs=v) y = mx.symbol.Variable("label") model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") @@ -454,7 +454,7 @@ def regression_model(m): n, m = 128, 100 model = regression_model(m) - data = mx.nd.zeros(shape=(n, m), storage_type='csr') + data = mx.nd.zeros(shape=(n, m), stype='csr') label = mx.nd.zeros((n, 1)) iterator = mx.io.NDArrayIter(data=data, label={'label':label}, batch_size=n) diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py index 3293ae2b0abc..9823036867d6 100644 --- a/tests/python/unittest/test_multi_device_exec.py +++ b/tests/python/unittest/test_multi_device_exec.py @@ -35,8 +35,8 @@ def test_ctx_group(): def check_ctx_group_sparse(lhs_stype, rhs_stype): with mx.AttrScope(ctx_group='stage1'): - lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) - rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + lhs = mx.symbol.Variable('lhs', stype=lhs_stype) + rhs = mx.symbol.Variable('rhs', stype=rhs_stype) plus = mx.symbol.elemwise_add(lhs, rhs, name='plus') set_stage1 = set(plus.list_arguments()) diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 66e13801cc30..1ba219c97aae 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -5,7 +5,7 @@ from numpy.testing import assert_allclose import numpy.random as rnd -from mxnet.sparse_ndarray import RowSparseNDArray, CSRNDArray, _ndarray_cls +from mxnet.ndarray import RowSparseNDArray, CSRNDArray def assert_fcompex(f, *args, **kwargs): @@ -15,13 +15,13 @@ def assert_fcompex(f, *args, **kwargs): def sparse_nd_ones(shape, stype): - return mx.nd.cast_storage(mx.nd.ones(shape), storage_type=stype) + return mx.nd.cast_storage(mx.nd.ones(shape), stype=stype) -def check_sparse_nd_elemwise_binary(shapes, storage_types, f, g): +def check_sparse_nd_elemwise_binary(shapes, stypes, f, g): # generate inputs nds = [] - for i, storage_type in enumerate(storage_types): + for i, storage_type in enumerate(stypes): if storage_type == 'row_sparse': nd, _ = rand_sparse_ndarray(shapes[i], storage_type) elif storage_type == 'default': @@ -63,7 +63,7 @@ def 
test_sparse_nd_elementwise_fallback(): def test_sparse_nd_zeros(): def check_sparse_nd_zeros(stype, shape): zero = mx.nd.zeros(shape) - sparse_zero = mx.nd.zeros(shape=shape, storage_type=stype) + sparse_zero = mx.nd.zeros(shape=shape, stype=stype) assert_almost_equal(sparse_zero.asnumpy(), zero.asnumpy()) shape = rand_shape_2d() @@ -102,22 +102,18 @@ def check_sparse_nd_prop_rsp(): def test_sparse_nd_basic(): def check_rsp_creation(values, indices, shape): - rsp = mx.sparse_nd.row_sparse(values, indices, shape) + rsp = mx.nd.row_sparse(values, indices, shape) dns = mx.nd.zeros(shape) dns[1] = mx.nd.array(values[0]) dns[3] = mx.nd.array(values[1]) - #assert_almost_equal(rsp.asnumpy(), dns.asnumpy()) - print('before', indices) - print('mx', mx.nd.array(indices, dtype='int64')[1].asnumpy()) indices_np = mx.nd.array(indices, dtype='int64').asnumpy() - print('after', indices_np) assert_almost_equal(rsp.indices.asnumpy(), indices_np) def check_csr_creation(shape): csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') assert_almost_equal(csr.indptr.asnumpy(), indptr) assert_almost_equal(csr.indices.asnumpy(), indices) - assert_almost_equal(csr.values.asnumpy(), values) + assert_almost_equal(csr.data.asnumpy(), values) shape = (4,2) values = np.random.rand(2,2) @@ -137,8 +133,8 @@ def check_csr_creation(shape): def test_sparse_nd_setitem(): - def check_sparse_nd_setitem(storage_type, shape, dst): - x = mx.nd.zeros(shape=shape, storage_type=storage_type) + def check_sparse_nd_setitem(stype, shape, dst): + x = mx.nd.zeros(shape=shape, stype=stype) x[:] = dst dst_nd = mx.nd.array(dst) if isinstance(dst, (np.ndarray, np.generic)) else dst assert same(x.asnumpy(), dst_nd.asnumpy()) @@ -170,7 +166,7 @@ def check_sparse_nd_csr_slice(shape): def test_sparse_nd_equal(): for stype in ['row_sparse', 'csr']: shape = rand_shape_2d() - x = mx.nd.zeros(shape=shape, storage_type=stype) + x = mx.nd.zeros(shape=shape, stype=stype) y = sparse_nd_ones(shape, stype) z = x == y assert (z.asnumpy() == np.zeros(shape)).all() @@ -181,7 +177,7 @@ def test_sparse_nd_equal(): def test_sparse_nd_not_equal(): for stype in ['row_sparse', 'csr']: shape = rand_shape_2d() - x = mx.nd.zeros(shape=shape, storage_type=stype) + x = mx.nd.zeros(shape=shape, stype=stype) y = sparse_nd_ones(shape, stype) z = x != y assert (z.asnumpy() == np.ones(shape)).all() @@ -192,7 +188,7 @@ def test_sparse_nd_not_equal(): def test_sparse_nd_greater(): for stype in ['row_sparse', 'csr']: shape = rand_shape_2d() - x = mx.nd.zeros(shape=shape, storage_type=stype) + x = mx.nd.zeros(shape=shape, stype=stype) y = sparse_nd_ones(shape, stype) z = x > y assert (z.asnumpy() == np.zeros(shape)).all() @@ -205,7 +201,7 @@ def test_sparse_nd_greater(): def test_sparse_nd_greater_equal(): for stype in ['row_sparse', 'csr']: shape = rand_shape_2d() - x = mx.nd.zeros(shape=shape, storage_type=stype) + x = mx.nd.zeros(shape=shape, stype=stype) y = sparse_nd_ones(shape, stype) z = x >= y assert (z.asnumpy() == np.zeros(shape)).all() @@ -220,7 +216,7 @@ def test_sparse_nd_greater_equal(): def test_sparse_nd_lesser(): for stype in ['row_sparse', 'csr']: shape = rand_shape_2d() - x = mx.nd.zeros(shape=shape, storage_type=stype) + x = mx.nd.zeros(shape=shape, stype=stype) y = sparse_nd_ones(shape, stype) z = y < x assert (z.asnumpy() == np.zeros(shape)).all() @@ -233,7 +229,7 @@ def test_sparse_nd_lesser(): def test_sparse_nd_lesser_equal(): for stype in ['row_sparse', 'csr']: shape = rand_shape_2d() - x = mx.nd.zeros(shape=shape, storage_type=stype) + x = 
mx.nd.zeros(shape=shape, stype=stype) y = sparse_nd_ones(shape, stype) z = y <= x assert (z.asnumpy() == np.zeros(shape)).all() @@ -328,7 +324,7 @@ def test_sparse_nd_negate(): def test_sparse_nd_output_fallback(): shape = (10, 10) - out = mx.nd.zeros(shape=shape, storage_type='row_sparse') + out = mx.nd.zeros(shape=shape, stype='row_sparse') mx.nd.random_normal(shape=shape, out=out) assert(np.sum(out.asnumpy()) != 0) @@ -336,7 +332,7 @@ def test_sparse_nd_output_fallback(): def test_sparse_nd_astype(): stypes = ['row_sparse', 'csr'] for stype in stypes: - x = mx.nd.zeros(shape=rand_shape_2d(), storage_type=stype, dtype='float32') + x = mx.nd.zeros(shape=rand_shape_2d(), stype=stype, dtype='float32') y = x.astype('int32') assert(y.dtype == np.int32), y.dtype diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 2466e6a94512..d0064a9265f8 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -2,13 +2,8 @@ def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_grad_stype=None): - lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) - rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) - if lhs_grad_stype is not None: - lhs._set_attr(grad_stype_hint=str(lhs_grad_stype)) - if rhs_grad_stype is not None: - rhs._set_attr(grad_stype_hint=str(rhs_grad_stype)) - + lhs = mx.symbol.Variable('lhs', stype=lhs_stype) + rhs = mx.symbol.Variable('rhs', stype=rhs_stype) lhs_nd = rand_ndarray(shape, lhs_stype) rhs_nd = rand_ndarray(shape, rhs_stype) lhs_np = lhs_nd.asnumpy() @@ -19,7 +14,13 @@ def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_ location = {'lhs': lhs_nd, 'rhs': rhs_nd} check_symbolic_forward(test, location, [out_np]) check_numeric_gradient(test, location) - check_symbolic_backward(test, location, [out_np], [out_np, out_np]) + grad_stypes = {} + if lhs_grad_stype is not None and lhs_grad_stype != 'default': + grad_stypes['lhs'] = lhs_grad_stype + if rhs_grad_stype is not None and rhs_grad_stype != 'default': + grad_stypes['rhs'] = rhs_grad_stype + check_symbolic_backward(test, location, [out_np], [out_np, out_np], + grad_stypes=grad_stypes) def test_elemwise_add_ex(): @@ -43,13 +44,13 @@ def test_elemwise_add_ex_multiple_stages(): val2 = mx.nd.array([[5, 10]]); idx1 = mx.nd.array([0], dtype=np.int64); idx2 = mx.nd.array([1], dtype=np.int64); - sp_nd1 = mx.sparse_nd.row_sparse(val1, idx1, shape) - sp_nd2 = mx.sparse_nd.row_sparse(val2, idx2, shape) + sp_nd1 = mx.nd.row_sparse(val1, idx1, shape) + sp_nd2 = mx.nd.row_sparse(val2, idx2, shape) ds_nd = mx.nd.array(ds_np) # sparse + sparse = sparse - sp_data1 = mx.symbol.Variable('sp_data1', storage_type='row_sparse') - sp_data2 = mx.symbol.Variable('sp_data2', storage_type='row_sparse') + sp_data1 = mx.symbol.Variable('sp_data1', stype='row_sparse') + sp_data2 = mx.symbol.Variable('sp_data2', stype='row_sparse') ds_data = mx.symbol.Variable('ds_data') plus = mx.symbol.elemwise_add(sp_data1, sp_data2, name='plus') # sparse + dense = dense @@ -69,7 +70,7 @@ def test_elemwise_add_ex_multiple_stages(): def test_cast_storage_ex(): def test_rsp_to_dns(shape): rsp, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse') - dns_out = mx.nd.cast_storage(rsp, storage_type='default') + dns_out = mx.nd.cast_storage(rsp, stype='default') dns_expected = np.zeros(shape, dtype=default_dtype()) if row_idx is not None: for k, v in enumerate(row_idx): @@ -78,8 +79,8 @@ def 
test_rsp_to_dns(shape): def test_dns_to_rsp(shape): dns_in = rand_ndarray(shape, 'default') - rsp_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), storage_type='row_sparse') - ret = mx.nd.cast_storage(rsp_out, storage_type='default') + rsp_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), stype='row_sparse') + ret = mx.nd.cast_storage(rsp_out, stype='default') assert same(ret.asnumpy(), dns_in.asnumpy()) def test_csr_to_dns(shape): @@ -90,8 +91,8 @@ def test_csr_to_dns(shape): def test_dns_to_csr(dns_in): dns_in = np.array(dns_in) - csr_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), storage_type='csr') - ret = mx.nd.cast_storage(csr_out, storage_type='default') + csr_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), stype='csr') + ret = mx.nd.cast_storage(csr_out, stype='default') assert same(ret.asnumpy(), dns_in) shape = rand_shape_2d() @@ -119,8 +120,8 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1): assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) # test symbolic forward - lhs = mx.symbol.Variable('lhs', storage_type='csr') - rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + lhs = mx.symbol.Variable('lhs', stype='csr') + rhs = mx.symbol.Variable('rhs', stype=rhs_stype) test = mx.symbol.dot(lhs, rhs, transpose_a=trans_lhs) location = {'lhs': lhs_nd, 'rhs': rhs_nd} expected = {'rhs': rhs_backward_grad} @@ -146,7 +147,7 @@ def test_sparse_embedding(): out_dim = 4 batch = 24 - data = mx.sym.Variable("data", storage_type='csr') + data = mx.sym.Variable("data", stype='csr') embed = mx.sym.SparseEmbedding(data=data, input_dim=in_dim, output_dim=out_dim, name="embed") exe_test = embed.simple_bind(default_context(), grad_req={'data': 'null', 'embed_weight': 'write'}, data=(batch, in_dim)) @@ -190,7 +191,7 @@ def test_sparse_retain(): for _ in range(10): shape = rand_shape_2d() num_rows = shape[0] - rsp, _ = rand_sparse_ndarray(shape=shape, storage_type='row_sparse', density=0.5) + rsp, _ = rand_sparse_ndarray(shape=shape, stype='row_sparse', density=0.5) length = np.random.randint(1, num_rows + 1) idx = random_sample(list(range(0, num_rows)), length) idx.sort()
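A few usage sketches for reviewers follow; they are illustrative only, written against the Python front end of this branch, and any helper not shown in the diff is an assumption. First, the argument rename from storage_type to stype, the move of the sparse constructors from mx.sparse_nd into mx.nd, and the CSRNDArray values -> data accessor rename exercised in test_sparse_ndarray.py:

    import numpy as np
    import mxnet as mx

    # dense <-> sparse casts now take `stype` instead of `storage_type`
    dense = mx.nd.ones((4, 3))
    rsp = mx.nd.cast_storage(dense, stype='row_sparse')
    csr = mx.nd.cast_storage(dense, stype='csr')

    # sparse constructors live under mx.nd after this change
    vals = mx.nd.array(np.ones((2, 3)))
    idx = mx.nd.array([0, 2], dtype=np.int64)
    rsp2 = mx.nd.row_sparse(vals, idx, (4, 3))

    # CSR values are now exposed as .data (was .values)
    assert np.array_equal(csr.data.asnumpy(), np.ones(12))

    # symbols declare their storage with `stype` as well
    x = mx.symbol.Variable('data', stype='csr')
    w = mx.symbol.var('w', shape=(3, 1), stype='row_sparse')
    y = mx.symbol.dot(x, w)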
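The test-utils hunk at the top of this section drops the grad_stype_hint attribute lookup in favor of an explicit grad_stypes dict, with the gradient buffers cast via mx.nd.cast_storage(nd, stype=grad_stypes[k]). The resulting caller-side pattern, mirroring check_elemwise_add_ex in test_sparse_operator.py (rand_ndarray and check_symbolic_backward come from mxnet.test_utils):

    import mxnet as mx
    from mxnet.test_utils import rand_ndarray, check_symbolic_backward

    shape = (4, 3)
    lhs = mx.symbol.Variable('lhs', stype='row_sparse')
    rhs = mx.symbol.Variable('rhs', stype='row_sparse')
    test = mx.symbol.elemwise_add(lhs, rhs)

    lhs_nd = rand_ndarray(shape, 'row_sparse')
    rhs_nd = rand_ndarray(shape, 'row_sparse')
    location = {'lhs': lhs_nd, 'rhs': rhs_nd}
    out_np = lhs_nd.asnumpy() + rhs_nd.asnumpy()

    # the helper now allocates these gradients as row_sparse by casting its
    # dense random gradient arrays with mx.nd.cast_storage(..., stype=...)
    check_symbolic_backward(test, location, [out_np], [out_np, out_np],
                            grad_stypes={'lhs': 'row_sparse', 'rhs': 'row_sparse'})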
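MXImperativeInvokeEx extends MXImperativeInvoke with an out_stypes array so the front end can pick the right NDArray subclass for each output without a follow-up MXNDArrayGetStorageType call per array. Assuming the ctypes/cython layers on this branch route imperative calls through the new entry point, the visible effect is simply that sparse results come back as the proper classes:

    import mxnet as mx
    from mxnet.ndarray import RowSparseNDArray, CSRNDArray

    out = mx.nd.cast_storage(mx.nd.ones((3, 4)), stype='row_sparse')
    assert isinstance(out, RowSparseNDArray)

    out = mx.nd.cast_storage(mx.nd.ones((3, 4)), stype='csr')
    assert isinstance(out, CSRNDArray)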
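On the kvstore side, CommCPU's parallel row_sparse reduction is now the shared mxnet::ndarray::ElementwiseSum kernel rather than the private ReduceSumCPUExParallel copy deleted above. A hedged end-to-end sketch of the path it serves; it assumes that on this branch pushing a list of row_sparse values for one key reduces them, and that pulling the key into a dense buffer casts the result (neither assumption is shown in the diff itself):

    import numpy as np
    import mxnet as mx

    shape = (4, 3)
    kv = mx.kv.create()
    kv.init(3, mx.nd.zeros(shape=shape, stype='row_sparse'))

    # two row_sparse updates for the same key exercise the CPU reduce path
    a = mx.nd.cast_storage(mx.nd.ones(shape), stype='row_sparse')
    b = mx.nd.row_sparse(mx.nd.array(np.ones((1, 3)) * 2),
                         mx.nd.array([1], dtype=np.int64), shape)
    kv.push(3, [a, b])

    out = mx.nd.zeros(shape)        # assumed: pull casts row_sparse -> dense
    kv.pull(3, out=out)
    expected = np.ones(shape)
    expected[1] += 2                # row 1 received both updates
    assert np.allclose(out.asnumpy(), expected)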
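Finally, the sparse_retain inference hooks now enforce their storage types with STORAGE_TYPE_ASSIGN_CHECK: the data input must be row_sparse, the index input default storage, and the output row_sparse. A forward-only sketch, assuming the operator is exposed as mx.nd.sparse_retain in this branch's Python front end:

    import numpy as np
    import mxnet as mx

    shape = (5, 2)
    dense = np.arange(10).reshape(shape).astype('float32')
    data = mx.nd.cast_storage(mx.nd.array(dense), stype='row_sparse')
    idx = mx.nd.array([0, 3])       # rows to keep, default (dense) storage

    retained = mx.nd.sparse_retain(data, idx)   # row_sparse in, row_sparse out
    expected = np.zeros(shape, dtype='float32')
    expected[[0, 3], :] = dense[[0, 3], :]
    assert np.allclose(mx.nd.cast_storage(retained, stype='default').asnumpy(),
                       expected)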