diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 6c1d1ab39176..e8c87729f946 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -241,9 +241,10 @@ def scatter_nd(data, indices, shape): f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " f"data[{i}] ({data.shape[i]})." ) - for i in range(int(indices.shape[0]), len(shape)): + mdim = int(indices.shape[0]) + for i in range(mdim, len(shape)): assert ( - data.shape[i] == out_shape[i] + data.shape[i-mdim] == shape[i] ), f"Dimension of data[{i}] must equal dimension of out_shape[{i}]" assert ( diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 5b23e8f4600e..18a46b17bb0a 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -57,6 +57,7 @@ from .space_to_depth import space_to_depth_python from .crop_and_resize_python import crop_and_resize_python from .common import ( + compare_numpy_tvm, get_injective_schedule, get_reduce_schedule, get_broadcast_schedule, diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py index 35a6040fa25a..5639662d5a9d 100644 --- a/python/tvm/topi/testing/common.py +++ b/python/tvm/topi/testing/common.py @@ -19,6 +19,9 @@ import tvm from tvm import topi +from tvm.testing import assert_allclose + +import numpy as np _injective_schedule = { "generic": topi.generic.schedule_injective, diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 546973704fea..eec6aa21c69b 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -989,7 +989,10 @@ def _body(i, st): body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) func = relay.Function([start], relay.TupleGetItem(body, 1)) with DiagnosticTesting() as diagnostics: - diagnostics.assert_message("in particular dimension 0 conflicts 2 does not match 1") + diagnostics.assert_message( + "The Relay type checker is unable to show the following types " + "match.\nIn particular dimension 0 conflicts: 2 does not match 1." + ) func = infer_type(func) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index a5cb916da613..a7443c65ac8c 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -122,8 +122,8 @@ def test_gather_nd_grad(): indices = relay.var("indices", relay.TensorType((2, 4), "int64")) fwd = relay.Function([data, indices], relay.gather_nd(data, indices)) data_np = np.random.rand(2, 3) - indices_np = np.array([[0, 2, 1, 0], [0, 1, 2, 1]]) - check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np]) + indices_np = np.array([[0, 1, 1, 0], [0, 1, 2, 1]]) + check_grad(fwd, inputs=[data_np, indices_np], test_inputs=indices_np) if __name__ == "__main__":