Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
hgt312 committed Oct 28, 2019
1 parent 69ddf4d commit db81cc1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'tril',
'meshgrid',
'outer',
'einsum'
'einsum',
'shares_memory',
'may_share_memory',
]


Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,8 @@ def check_interoperability(op_list):
for name in op_list:
if name in _TVM_OPS and not is_op_runnable():
continue
if name in ['shares_memory', 'may_share_memory']: # skip list
continue
print('Dispatch test:', name)
workloads = OpArgMngr.get_workloads(name)
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
Expand All @@ -1243,6 +1245,19 @@ def check_interoperability(op_list):
_check_interoperability_helper(name, *workload['args'], **workload['kwargs'])


@with_seed()
@use_np
@with_array_function_protocol
def test_np_memory_array_function():
ops = [_np.shares_memory, _np.may_share_memory]
for op in ops:
data_mx = np.zeros([13, 21, 23, 22], dtype=np.float32)
data_np = _np.zeros([13, 21, 23, 22], dtype=np.float32)
assert op(data_mx[0,:,:,:], data_mx[1,:,:,:]) == op(data_np[0,:,:,:], data_np[1,:,:,:])
assert op(data_mx[0,0,0,2:5], data_mx[0,0,0,4:7]) == op(data_np[0,0,0,2:5], data_np[0,0,0,4:7])
assert op(data_mx, np.ones((5, 0))) == op(data_np, _np.ones((5, 0)))


@with_seed()
@use_np
@with_array_function_protocol
Expand Down

0 comments on commit db81cc1

Please sign in to comment.