Skip to content

Commit

Permalink
fix nnvm
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Mar 27, 2019
1 parent e0444be commit 23e1659
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions nnvm/tests/python/compiler/test_top_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1),
m = graph_runtime.create(graph, lib, ctx)
m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
m.run()
out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
tvm.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
tvm_out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)

def test_multibox_prior():
verify_multibox_prior((1, 3, 50, 50))
Expand Down Expand Up @@ -568,8 +568,8 @@ def test_multibox_transform_loc():
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
m.run()
out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
tvm_out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)

def test_non_max_suppression():
dshape = (1, 5, 6)
Expand All @@ -595,8 +595,8 @@ def test_non_max_suppression():
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"data": np_data, "valid_count": np_valid_count})
m.run()
out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
tvm.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
tvm_outout = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)

def np_slice_like(np_data, np_shape_like, axis=[]):
begin_idx = [0 for _ in np_data.shape]
Expand Down Expand Up @@ -707,7 +707,6 @@ def test_argmax():
np.testing.assert_allclose(out.asnumpy(), np_argmax, atol=1e-5, rtol=1e-5)

if __name__ == "__main__":
test_non_max_suppression()
test_reshape()
test_broadcast()
test_reduce()
Expand Down

0 comments on commit 23e1659

Please sign in to comment.