Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Revert "Revert "Change storage_type to stype""
Browse files Browse the repository at this point in the history
This reverts commit 0932838.

Move ndarray.py, sparse_ndarray.py, ndarray_utils.py, and _ndarray_internal to ndarray folder

More refactor

Move elementwise sum for rsp to ndarray_function.cc

Remove unnecessary import in ndarray module

Fix pylint

Remove redundant code

Remove _stype from slots

Fix cpp-package build error caused by the change to imperative invoke interface

Use relative import

Remove print line

Rename _ndarray_internal.py to _internal.py
  • Loading branch information
reminisce committed Jul 11, 2017
1 parent 75fda82 commit 0efb93a
Show file tree
Hide file tree
Showing 32 changed files with 835 additions and 742 deletions.
143 changes: 82 additions & 61 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,17 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
NDArrayHandle *out);
const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
NDArrayHandle *out);

/*!
* \brief create a NDArray handle that is loaded from raw bytes.
Expand Down Expand Up @@ -406,7 +406,7 @@ MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
* \brief get the storage type of the array
*/
MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle,
int *out_storage_type);
int *out_storage_type);

/*!
* \brief Reshape the NDArray.
Expand Down Expand Up @@ -604,8 +604,29 @@ MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator,
NDArrayHandle **outputs,
int num_params,
const char **param_keys,
const char **param_vals,
const int** out_stypes);
const char **param_vals);
/*!
* \brief invoke a nnvm op and imperative function
* \param creator the op
* \param num_inputs number of input NDArrays
* \param inputs input NDArrays
* \param num_outputs number of output NDArrays
* \param outputs output NDArrays
* \param num_params number of keyword parameters
* \param param_keys keys for keyword parameters
* \param param_vals values for keyword parameters
* \param out_stypes output ndarrays' stypes
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator,
int num_inputs,
NDArrayHandle *inputs,
int *num_outputs,
NDArrayHandle **outputs,
int num_params,
const char **param_keys,
const char **param_vals,
const int **out_stypes);
/*!
* \brief set whether to record operator for autograd
* \param is_train 1 when training, 0 when testing
Expand Down Expand Up @@ -1020,20 +1041,20 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
int *complete);
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
int *complete);

/*!
* \brief infer type of unknown input types given the known one.
Expand Down Expand Up @@ -1216,39 +1237,39 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
ExecutorHandle *out);

MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
/*!
* \brief set a call back to notify the completion of operation
*/
Expand Down
5 changes: 0 additions & 5 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@

namespace mxnet {

namespace ndarray {
template<typename from_xpu, typename to_xpu>
void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx);
};

namespace autograd {
class AGNode;

Expand Down
8 changes: 4 additions & 4 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
* Dispatched only when operators process non-default storage inputs or outputs
*/
using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs)>;
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs)>;

using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
const Context& ctx,
Expand Down
7 changes: 1 addition & 6 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from . import base
from . import contrib
from . import ndarray
from . import sparse_ndarray
from . import ndarray_utils
from . import ndarray as nd
from . import name
# use mx.sym as short for symbol
from . import symbol as sym
Expand All @@ -18,10 +17,6 @@
from . import io
from . import recordio
from . import operator
# use mx.nd as short for mx.ndarray
from . import ndarray as nd
from . import sparse_ndarray as sparse_nd
from . import ndarray_utils as nd_utils
# use mx.rnd as short for mx.random
from . import random as rnd
from . import random
Expand Down
11 changes: 5 additions & 6 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@

class NDArrayBase(object):
"""Base data structure for ndarray"""
__slots__ = ["handle", "writable", "_stype"]
__slots__ = ["handle", "writable"]
# pylint: disable= no-member

def __init__(self, handle, writable=True, stype=None):
def __init__(self, handle, writable=True):
"""initialize a new NDArray
Parameters
Expand All @@ -41,7 +41,6 @@ def __init__(self, handle, writable=True, stype=None):
assert isinstance(handle, NDArrayHandle)
self.handle = handle
self.writable = writable
self._stype = stype

def __del__(self):
check_call(_LIB.MXNDArrayFree(self.handle))
Expand Down Expand Up @@ -76,7 +75,7 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
# a handle's stype in _ndarray_cls
out_stypes = ctypes.POINTER(ctypes.c_int)()

check_call(_LIB.MXImperativeInvoke(
check_call(_LIB.MXImperativeInvokeEx(
ctypes.c_void_p(handle),
ctypes.c_int(len(ndargs)),
c_array(NDArrayHandle, [arr.handle for arr in ndargs]),
Expand All @@ -93,8 +92,8 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]])
else:
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle,
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]]))
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]])
for i in range(num_output.value)]


Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .base import mx_uint, NDArrayHandle, ExecutorHandle
from .base import check_call, c_array, py_str
from .ndarray import NDArray
from .sparse_ndarray import _ndarray_cls
from .ndarray import _ndarray_cls
from . import ndarray as nd

# those functions are not used here, we just import them to keep backward compatibility
Expand Down
6 changes: 3 additions & 3 deletions python/mxnet/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from .base import numeric_types
from . import ndarray as nd
from . import _ndarray_internal as _internal
from ._ndarray_internal import _cvimresize as imresize
from ._ndarray_internal import _cvcopyMakeBorder as copyMakeBorder
from .ndarray import _internal
from .ndarray._internal import _cvimresize as imresize
from .ndarray._internal import _cvcopyMakeBorder as copyMakeBorder
from . import io
from . import recordio

Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .base import mx_real_t
from .base import check_call, build_param_doc as _build_param_doc
from .ndarray import NDArray
from .sparse_ndarray import _ndarray_cls
from .ndarray import _ndarray_cls
from .ndarray import array
from .ndarray import concatenate

Expand Down
7 changes: 3 additions & 4 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
import logging
import warnings

import mxnet as mx
from .. import context as ctx
from .. import ndarray as nd
from .. import optimizer as opt

from .executor_group import DataParallelExecutorGroup
from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
from ..model import load_checkpoint
from ..initializer import Uniform, InitDesc
from ..ndarray import zeros

from .base_module import BaseModule, _check_input_names, _parse_data_desc

Expand Down Expand Up @@ -399,13 +398,13 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
else:
assert self._arg_params is None and self._aux_params is None
param_arrays = [
mx.nd.zeros(shape=x[0].shape, dtype=x[0].dtype, storage_type=x[0].stype)
zeros(shape=x[0].shape, dtype=x[0].dtype, stype=x[0].stype)
for x in self._exec_group.param_arrays
]
self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)}

aux_arrays = [
nd.zeros(x[0].shape, dtype=x[0].dtype)
zeros(x[0].shape, dtype=x[0].dtype)
for x in self._exec_group.aux_arrays
]
self._aux_params = {name:arr for name, arr in zip(self._aux_names, aux_arrays)}
Expand Down
12 changes: 12 additions & 0 deletions python/mxnet/ndarray/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""ndarray module"""

from . import _internal
from . import op
from .op import CachedOp, invoke
from .ndarray import NDArray, array, concatenate, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from .ndarray import empty, ones, add, arange, divide, equal, full, greater, greater_equal, imdecode
from .ndarray import lesser, lesser_equal, maximum, minimum, moveaxis, multiply, negative, not_equal
from .ndarray import onehot_encode, power, subtract, true_divide, waitall, _new_empty_handle
from .ndarray_utils import load, save, zeros
from .sparse_ndarray import _ndarray_cls
from .sparse_ndarray import csr, row_sparse, SparseNDArray, todense, RowSparseNDArray, CSRNDArray
File renamed without changes.
Loading

0 comments on commit 0efb93a

Please sign in to comment.