Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tkonolige committed Nov 5, 2020
1 parent f0056f7 commit 7ba042f
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 5 deletions.
5 changes: 3 additions & 2 deletions python/tvm/topi/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,10 @@ def scatter_nd(data, indices, shape):
f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
f"data[{i}] ({data.shape[i]})."
)
for i in range(int(indices.shape[0]), len(shape)):
mdim = int(indices.shape[0])
for i in range(mdim, len(shape)):
assert (
data.shape[i] == out_shape[i]
data.shape[i-mdim] == shape[i]
), f"Dimension of data[{i}] must equal dimension of out_shape[{i}]"

assert (
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from .space_to_depth import space_to_depth_python
from .crop_and_resize_python import crop_and_resize_python
from .common import (
compare_numpy_tvm,
get_injective_schedule,
get_reduce_schedule,
get_broadcast_schedule,
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/topi/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

import tvm
from tvm import topi
from tvm.testing import assert_allclose

import numpy as np

_injective_schedule = {
"generic": topi.generic.schedule_injective,
Expand Down
5 changes: 4 additions & 1 deletion tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,10 @@ def _body(i, st):
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1))
with DiagnosticTesting() as diagnostics:
diagnostics.assert_message("in particular dimension 0 conflicts 2 does not match 1")
diagnostics.assert_message(
"The Relay type checker is unable to show the following types "
"match.\nIn particular dimension 0 conflicts: 2 does not match 1."
)
func = infer_type(func)


Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_op_grad_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def test_gather_nd_grad():
indices = relay.var("indices", relay.TensorType((2, 4), "int64"))
fwd = relay.Function([data, indices], relay.gather_nd(data, indices))
data_np = np.random.rand(2, 3)
indices_np = np.array([[0, 2, 1, 0], [0, 1, 2, 1]])
check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np])
indices_np = np.array([[0, 1, 1, 0], [0, 1, 2, 1]])
check_grad(fwd, inputs=[data_np, indices_np], test_inputs=indices_np)


if __name__ == "__main__":
Expand Down

0 comments on commit 7ba042f

Please sign in to comment.