Commit 60d74bc
Showing proper error message when an attempt is made to create a large tensor but MXNet is not built with it (#16570)

* added error message when creating an NDArray and syncing copies to and from CPU

* bug fix: 1>>31 -> uint32_t{1}<<31; also show an error message when a user attempts to pass large NDArrays as operator inputs, or requests a large input shape for memory allocation

* adding error messages to other operators where a user can create an NDArray indirectly

* bug fix

* removing additional checks

* Revert "removing additional checks"

This reverts commit d035559.

* adding tests and comments. Removed int64 check from linspace
access2rohit authored and zheng-da committed Oct 29, 2019
1 parent 86ed5f5 commit 60d74bc
Showing 9 changed files with 144 additions and 12 deletions.
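
For context, a minimal sketch of the user-visible behavior these changes add, assuming MXNet is built without USE_INT64_TENSOR_SIZE=1 (the shape mirrors the tests added at the bottom of this commit):

    import mxnet as mx
    from mxnet.base import MXNetError

    huge = (2, 4300000000)  # ~8.6e9 elements, well past the 2^31 - 1 limit
    try:
        mx.nd.ones(huge)
    except MXNetError as e:
        print(e)  # ... larger than 2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1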
12 changes: 6 additions & 6 deletions include/mxnet/c_api.h
@@ -596,12 +596,12 @@ MXNET_DLL int MXNDArrayCreate(const uint32_t *shape,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayCreateEx(const uint32_t *shape,
                              uint32_t ndim,
                              int dev_type,
                              int dev_id,
                              int delay_alloc,
                              int dtype,
                              NDArrayHandle *out);
                                uint32_t ndim,
                                int dev_type,
                                int dev_id,
                                int delay_alloc,
                                int dtype,
                                NDArrayHandle *out);

MXNET_DLL int MXNDArrayCreateEx64(const int64_t *shape,
                                  int ndim,
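As an aside, whether a given build has large-tensor support can be queried from Python at runtime; a small sketch, assuming the runtime feature name matches the features::is_enabled(features::INT64_TENSOR_SIZE) calls used throughout the C++ changes below:

    from mxnet.runtime import Features

    # True only when MXNet was compiled with USE_INT64_TENSOR_SIZE=1;
    # otherwise tensors are capped at 2^31 - 1 elements.
    print(Features().is_enabled('INT64_TENSOR_SIZE'))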
10 changes: 10 additions & 0 deletions python/mxnet/ndarray/ndarray.py
@@ -58,6 +58,7 @@
_STORAGE_TYPE_DEFAULT = 0
_STORAGE_TYPE_ROW_SPARSE = 1
_STORAGE_TYPE_CSR = 2
_SIGNED_INT32_UPPER_LIMIT = (2**31 - 1)

# pylint: disable= no-member
_DTYPE_NP_TO_MX = {
@@ -155,6 +156,15 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
            ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
            ctypes.byref(hdl)))
    else:
        # When the shape is larger than uint32 can hold, the overflow happens on
        # the Python side itself. It has to be caught here, since the call would
        # never reach the backend.
        size = 1
        for idx in shape:
            size = size * idx
        if size > _SIGNED_INT32_UPPER_LIMIT:
            raise Exception("[_new_alloc_handle] Size of tensor you are trying to allocate is " +
                            "larger than 2^31 elements. Please build with flag " +
                            "USE_INT64_TENSOR_SIZE=1")
        check_call(_LIB.MXNDArrayCreateEx(
            c_array_buf(mx_uint, native_array('I', shape)),
            mx_uint(len(shape)),
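The new Python-side guard boils down to the following standalone check; a minimal sketch, with the bound taken from _SIGNED_INT32_UPPER_LIMIT in the diff above:

    from functools import reduce
    import operator

    _SIGNED_INT32_UPPER_LIMIT = 2**31 - 1

    def check_tensor_size(shape):
        # Product of all dimensions; a total past 2^31 - 1 would overflow
        # the uint32-based MXNDArrayCreateEx call, so fail early in Python.
        size = reduce(operator.mul, shape, 1)
        if size > _SIGNED_INT32_UPPER_LIMIT:
            raise Exception("Size of tensor you are trying to allocate is larger "
                            "than 2^31 elements. Please build with flag "
                            "USE_INT64_TENSOR_SIZE=1")

    check_tensor_size((2, 1000))          # passes silently
    # check_tensor_size((2, 4300000000))  # would raise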
7 changes: 6 additions & 1 deletion python/mxnet/symbol/symbol.py
@@ -39,7 +39,7 @@
from ..base import check_call, MXNetError, NotImplementedForSymbol
from ..context import Context, current_context
from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _int64_enabled
from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _int64_enabled, _SIGNED_INT32_UPPER_LIMIT
from ..ndarray import _ndarray_cls
from ..executor import Executor
from . import _internal
@@ -1237,6 +1237,11 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
                ctypes.byref(aux_shape_data),
                ctypes.byref(complete)))
        else:
            for size in sdata:
                if size > _SIGNED_INT32_UPPER_LIMIT:
                    raise Exception("[_infer_shape_impl] Size of tensor you are trying to " +
                                    "allocate is larger than 2^31 elements. Please build " +
                                    "with flag USE_INT64_TENSOR_SIZE=1")
            arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
            out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
            aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
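A hedged example of how the _infer_shape_impl guard surfaces through the symbol API on a non-int64 build (the network here is illustrative):

    import mxnet as mx

    data = mx.sym.Variable('data')
    net = mx.sym.relu(data)
    try:
        net.infer_shape(data=(2, 4300000000))  # one dimension past 2^31 - 1
    except Exception as e:
        print(e)  # [_infer_shape_impl] Size of tensor you are trying to allocate ...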
13 changes: 12 additions & 1 deletion src/c_api/c_api.cc
@@ -220,7 +220,13 @@ void CreateNDArray(const DataType* shape,
                   int delay_alloc,
                   int dtype,
                   NDArrayHandle* out) {
  *out = new NDArray(mxnet::TShape(shape, shape + ndim),
  mxnet::TShape requested_shape = mxnet::TShape(shape, shape + ndim);
  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(requested_shape.Size(), (int64_t{1} << 31) - 1) <<
      "[CreateNDArray] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  *out = new NDArray(requested_shape,
                     Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
                     delay_alloc != 0, dtype);
}
@@ -608,6 +614,11 @@ inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim,
                     MXAPIThreadLocalEntry<dtype>* ret) {
  NDArray* arr = static_cast<NDArray*>(handle);
  if (!arr->is_none()) {
    if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
      CHECK_LT(arr->shape().Size(), (int64_t{1} << 31) - 1) <<
        "[Get Shape] Size of tensor you are trying to allocate is larger than "
        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
    }
    mxnet::TShape s = arr->shape();
    if (!Imperative::Get()->is_np_shape()) {
      common::ConvertToLegacyShape(&s);
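The CreateNDArray and GetShape guards back up the Python-side check for frontends that call the C API directly; from Python, an oversized creation already fails at or before this boundary. A hedged sketch (int8 keeps the host-side array to about 2 GB):

    import numpy as np
    import mxnet as mx

    x = np.zeros((2**31 + 1,), dtype=np.int8)  # just over the signed-int32 limit
    try:
        mx.nd.array(x)
    except Exception as e:  # raised by the Python guard or the backend CHECK
        print(e)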
8 changes: 7 additions & 1 deletion src/c_api/c_api_ndarray.cc
@@ -54,7 +54,13 @@ void SetNDInputsOutputs(const nnvm::Op* op,
  ndinputs->clear();
  ndinputs->reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    ndinputs->emplace_back(reinterpret_cast<NDArray*>(inputs[i]));
    NDArray* inp = reinterpret_cast<NDArray*>(inputs[i]);
    if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
      CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1) <<
        "[SetNDInputsOutputs] Size of tensor you are trying to allocate is larger than "
        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
    }
    ndinputs->emplace_back(inp);
  }

  ndoutputs->clear();
15 changes: 15 additions & 0 deletions src/ndarray/ndarray.cc
@@ -142,6 +142,11 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) {
  CHECK_NE(aux_shapes.size(), 0)
      << "data is expected to be allocated after aux_data";
  auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(shape.Size(), (int64_t{1} << 31) - 1) <<
      "[CheckAndAllocData] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  if (shandle.size < dbytes) {
    // free storage
    Storage::Get()->Free(shandle);
@@ -1884,6 +1889,11 @@ NDArray NDArray::Copy(Context ctx) const {

void NDArray::SyncCopyFromCPU(const void *data, size_t size) const {
  mxnet::TShape dshape = this->shape();
  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(size, (int64_t{1} << 31) - 1) <<
      "[SyncCopyFromCPU] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  CHECK_EQ(dshape.Size(), size)
      << "Memory size do not match";
  // zero-size array, no need to copy
@@ -2019,6 +2029,11 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) {

void NDArray::SyncCopyToCPU(void *data, size_t size) const {
  mxnet::TShape dshape = this->shape();
  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(size, (int64_t{1} << 31) - 1) <<
      "[SyncCopyToCPU] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  CHECK_EQ(dshape.Size(), size)
      << "Memory size do not match";
  // zero-size array, no need to copy
5 changes: 5 additions & 0 deletions src/ndarray/ndarray_function.cc
@@ -38,6 +38,11 @@ void Copy<cpu, cpu>(const TBlob &from, TBlob *to,
                    RunContext ctx) {
  MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, {
    if (to->type_flag_ == from.type_flag_) {
      if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
        CHECK_LT(from.Size(), (int64_t{1} << 31) - 1) <<
          "Size of tensor you are trying to allocate is larger than "
          "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
      }
      const index_t size = static_cast<index_t>(from.Size());
      CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(DType)
                                 << " bytes, to: " << to->Size() * sizeof(DType) << " bytes.";
30 changes: 27 additions & 3 deletions src/operator/tensor/init_op.h
@@ -272,10 +272,22 @@ inline bool InitShape(const nnvm::NodeAttrs& attrs,
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  mxnet::TShape param_shape = param.shape;
  if (shape_is_known(param_shape) && !features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(param_shape.Size(), (int64_t{1} << 31) - 1) <<
      "[InitShape-input] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  if (!Imperative::Get()->is_np_shape()) {
    common::ConvertToNumpyShape(&param_shape);
  }
  if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) return true;
  if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) {
    if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
      CHECK_LT(out_attrs->at(0).Size(), (int64_t{1} << 31) - 1) <<
        "[InitShape-output] Size of tensor you are trying to allocate is larger than "
        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
    }
    return true;
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, param_shape);
  return shape_is_known(out_attrs->at(0));
}
@@ -336,6 +348,11 @@ inline bool InitStorageType(const nnvm::NodeAttrs& attrs,
template <bool is_integer = false, typename ValueType, typename xpu>
void Fill(mshadow::Stream<xpu> *s, const TBlob& b, const OpReqType req, ValueType val) {
  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(b.Size(), (int64_t{1} << 31) - 1) <<
      "[Fill] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  // If b is a zero-size tensor, do nothing.
  if (b.Size() == 0) return;
  if (req != kNullOp) {
    const size_t size = b.Size();
@@ -580,7 +597,13 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
  }
  const double out_size = std::ceil((param.stop.value() - param.start) / param.step)
      * param.repeat;
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
  mxnet::TShape output_shape = mxnet::TShape({static_cast<nnvm::dim_t>(out_size)});
  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(output_shape.Size(), (int64_t{1} << 31) - 1) <<
      "[RangeShape] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, output_shape);
  return true;
}

@@ -622,7 +645,8 @@ inline bool LinspaceShape(const nnvm::NodeAttrs& attrs,
  CHECK_EQ(out_attrs->size(), 1U);
  CHECK_GE(param.num, 0)
      << "Number of sequence should be non-negative, received " << param.num;
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(param.num)}));
  mxnet::TShape shape = mxnet::TShape({static_cast<nnvm::dim_t>(param.num)});
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape);
  return true;
}

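Grounded in the tests added below, a quick example of the RangeShape guard on a non-int64 build:

    import mxnet as mx

    try:
        mx.nd.arange(0, 4300000000)  # the output would need more than 2^31 - 1 elements
    except Exception as e:
        print(e)  # [RangeShape] Size of tensor you are trying to allocate ...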
56 changes: 56 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -9279,6 +9279,62 @@ def test_min_max_inf():
    assert_array_equal(max_data_np, max_data_mx.asnumpy())


def test_large_tensor_disabled_err_msg():
    LARGE_X = 4300000000
    MEDIUM_X = 1000000000
    SMALL_Y = 1
    shape = (2, LARGE_X)

    def check_nd_array():
        x = np.arange(0, LARGE_X)
        assertRaises(MXNetError, mx.nd.array, x)

    def check_nd_ones():
        assertRaises(MXNetError, mx.nd.ones, shape)

    def check_nd_zeros():
        assertRaises(MXNetError, mx.nd.zeros, shape)

    def check_nd_full():
        val = 1
        assertRaises(Exception, mx.nd.full, shape, val)

    def check_nd_arange():
        start = 0
        stop = LARGE_X
        assertRaises(Exception, mx.nd.arange, start, stop)

    def check_nd_random():
        shape = (2, LARGE_X)

        def check_random_exp():
            lam = 4
            assertRaises(MXNetError, mx.nd.random_exponential, lam, shape)

        def check_random_gamma():
            alpha = 9
            beta = 0.5
            assertRaises(MXNetError, mx.nd.random_gamma, alpha, beta, shape)

        def check_random_normal():
            loc = 0
            scale = 1
            assertRaises(MXNetError, mx.nd.random_normal, loc, scale, shape)

        def check_random_poisson():
            lam = 4
            assertRaises(MXNetError, mx.nd.random_poisson, lam, shape)

        def check_random_randint():
            low = 0
            high = 1000000
            assertRaises(MXNetError, mx.nd.random_randint, low, high, shape)

        def check_random_uniform():
            low = 0
            high = 1
            assertRaises(MXNetError, mx.nd.random_uniform, low, high, shape)

        check_random_exp()
        check_random_gamma()
        check_random_normal()
        check_random_poisson()
        check_random_randint()
        check_random_uniform()

    check_nd_array()
    check_nd_ones()
    check_nd_zeros()
    check_nd_full()
    check_nd_arange()
    check_nd_random()

if __name__ == '__main__':
    import nose
    nose.runmodule()
