Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jun 2, 2021
1 parent f9379fb commit 541403f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 15 deletions.
7 changes: 3 additions & 4 deletions python/tvm/topi/testing/strided_slice_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""strided_slice/set in python"""


def strided_slice_python(data, begin, end, strides, axes=None, slice_mode="end"):
def strided_slice_python(data, begin, end, strides, slice_mode="end", axes=None):
"""Python version of strided slice operator.
Parameters
Expand All @@ -34,16 +34,15 @@ def strided_slice_python(data, begin, end, strides, axes=None, slice_mode="end")
strides : list
The stride of each slice.
axes : list, optional
Axes along which slicing is applied
slice_mode : str, optional
The slice mode [end, size].
end: The default slice mode, ending indices for the slice.
size: The input strides will be ignored, input end in this mode indicates
the sizeof a slice starting at the location specified by begin. If end[i] is -1,
all remaining elements in that dimension are included in the slice.
axes : list, optional
Axes along which slicing is applied
Returns
-------
Expand Down
20 changes: 10 additions & 10 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def verify(
# target numpy result
x_data = np.random.uniform(size=dshape).astype("float32")
ref_res = tvm.topi.testing.strided_slice_python(
x_data, begin, end, strides, axes, slice_mode
x_data, begin, end, strides, slice_mode, axes=axes,
)

if strides:
Expand Down Expand Up @@ -474,7 +474,7 @@ def verify(
# target numpy result
x_data = np.random.uniform(size=dshape).astype("float32")
ref_res = tvm.topi.testing.strided_slice_python(
x_data, begin, end, strides, axes, slice_mode
x_data, begin, end, strides, slice_mode, axes=axes
)

if ishape is None:
Expand Down Expand Up @@ -576,11 +576,11 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True):
if __name__ == "__main__":
test_strided_slice()
test_dyn_strided_slice()
test_strided_set()
test_binary_op()
test_cmp_type()
test_binary_int_broadcast_1()
test_binary_int_broadcast_2()
test_where()
test_reduce_functions()
test_mean_var_std()
# test_strided_set()
# test_binary_op()
# test_cmp_type()
# test_binary_int_broadcast_1()
# test_binary_int_broadcast_2()
# test_where()
# test_reduce_functions()
# test_mean_var_std()
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def check_device(target):

foo = tvm.build(s, [A, B], target, name="stride_slice")
x_np = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides, axes) + 1
out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides, axes=axes) + 1
data_nd = tvm.nd.array(x_np, dev)
out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype)
foo(data_nd, out_nd)
Expand Down

0 comments on commit 541403f

Please sign in to comment.