diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index cec2f245a5e1..6a5f166a70eb 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -127,7 +127,9 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'tril', 'meshgrid', 'outer', - 'einsum' + 'einsum', + 'shares_memory', + 'may_share_memory', ] diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 860fecc5cda0..624fc0a107b0 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1234,6 +1234,8 @@ def check_interoperability(op_list): for name in op_list: if name in _TVM_OPS and not is_op_runnable(): continue + if name in ['shares_memory', 'may_share_memory']: # skip list + continue print('Dispatch test:', name) workloads = OpArgMngr.get_workloads(name) assert workloads is not None, 'Workloads for operator `{}` has not been ' \ @@ -1243,6 +1245,19 @@ def check_interoperability(op_list): _check_interoperability_helper(name, *workload['args'], **workload['kwargs']) +@with_seed() +@use_np +@with_array_function_protocol +def test_np_memory_array_function(): + ops = [_np.shares_memory, _np.may_share_memory] + for op in ops: + data_mx = np.zeros([13, 21, 23, 22], dtype=np.float32) + data_np = _np.zeros([13, 21, 23, 22], dtype=np.float32) + assert op(data_mx[0,:,:,:], data_mx[1,:,:,:]) == op(data_np[0,:,:,:], data_np[1,:,:,:]) + assert op(data_mx[0,0,0,2:5], data_mx[0,0,0,4:7]) == op(data_np[0,0,0,2:5], data_np[0,0,0,4:7]) + assert op(data_mx, np.ones((5, 0))) == op(data_np, _np.ones((5, 0))) + + @with_seed() @use_np @with_array_function_protocol