From 541403fe17549bd3978e623bdf410a24f708f680 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 20:41:11 +0900 Subject: [PATCH] fix tests --- .../tvm/topi/testing/strided_slice_python.py | 7 +++---- tests/python/relay/test_op_level4.py | 20 +++++++++---------- .../python/topi/python/test_topi_transform.py | 2 +- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/python/tvm/topi/testing/strided_slice_python.py b/python/tvm/topi/testing/strided_slice_python.py index e630ace696778..3843d0996777c 100644 --- a/python/tvm/topi/testing/strided_slice_python.py +++ b/python/tvm/topi/testing/strided_slice_python.py @@ -17,7 +17,7 @@ """strided_slice/set in python""" -def strided_slice_python(data, begin, end, strides, axes=None, slice_mode="end"): +def strided_slice_python(data, begin, end, strides, slice_mode="end", axes=None): """Python version of strided slice operator. Parameters @@ -34,9 +34,6 @@ def strided_slice_python(data, begin, end, strides, axes=None, slice_mode="end") strides : list The stride of each slice. - axes : list, optional - Axes along which slicing is applied - slice_mode : str, optional The slice mode [end, size]. end: The default slice mode, ending indices for the slice. @@ -44,6 +41,8 @@ def strided_slice_python(data, begin, end, strides, axes=None, slice_mode="end") the sizeof a slice starting at the location specified by begin. If end[i] is -1, all remaining elements in that dimension are included in the slice. + axes : list, optional + Axes along which slicing is applied Returns ------- diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 607b68ea77cce..2a50580b8c7a6 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -398,7 +398,7 @@ def verify( # target numpy result x_data = np.random.uniform(size=dshape).astype("float32") ref_res = tvm.topi.testing.strided_slice_python( - x_data, begin, end, strides, axes, slice_mode + x_data, begin, end, strides, slice_mode, axes=axes, ) if strides: @@ -474,7 +474,7 @@ def verify( # target numpy result x_data = np.random.uniform(size=dshape).astype("float32") ref_res = tvm.topi.testing.strided_slice_python( - x_data, begin, end, strides, axes, slice_mode + x_data, begin, end, strides, slice_mode, axes=axes ) if ishape is None: @@ -576,11 +576,11 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True): if __name__ == "__main__": test_strided_slice() test_dyn_strided_slice() - test_strided_set() - test_binary_op() - test_cmp_type() - test_binary_int_broadcast_1() - test_binary_int_broadcast_2() - test_where() - test_reduce_functions() - test_mean_var_std() + # test_strided_set() + # test_binary_op() + # test_cmp_type() + # test_binary_int_broadcast_1() + # test_binary_int_broadcast_2() + # test_where() + # test_reduce_functions() + # test_mean_var_std() diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index c98478885a102..ddde2e20e754c 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -416,7 +416,7 @@ def check_device(target): foo = tvm.build(s, [A, B], target, name="stride_slice") x_np = np.random.uniform(size=in_shape).astype(A.dtype) - out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides, axes) + 1 + out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides, axes=axes) + 1 data_nd = tvm.nd.array(x_np, dev) out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype) foo(data_nd, out_nd)