diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 6207c2c1ed46..114ff2a14f70 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -121,6 +121,20 @@ def cast_shape_func(attrs, inputs, out_ndims): return [_cast_shape_function(*inputs)] # shape func +@script +def _full_shape_func(x): + out_ndim = len(x) + out = output_tensor((out_ndim,), "int64") + for i in const_range(out_ndim): + out[i] = x[i] + return out + +def full_shape_func(attrs, inputs, out_ndims): + """ + Shape func for zeros, zeros_like, ones, ones_like. + """ + return [_full_shape_func(*inputs)] + @script def _broadcast_shape_func(x, y, ndim): out = output_tensor((ndim,), "int64") @@ -162,6 +176,10 @@ def elemwise_shape_func(attrs, inputs, _): return [topi.math.identity(inputs[0])] register_shape_func("cast", False, cast_shape_func) +register_shape_func("zeros", False, full_shape_func) +register_shape_func("zeros_like", False, full_shape_func) +register_shape_func("ones", False, full_shape_func) +register_shape_func("ones_like", False, full_shape_func) register_shape_func("add", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 75be88cbcb19..9e0208f4fa58 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -41,10 +41,11 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): mod["main"] = relay.Function([x, y], op(x, y)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) y_np = np.random.uniform(size=y_np_shape).astype(dtype) + res_np = np_op(x_np, y_np) for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()(x_np, y_np) - tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np)) + tvm.testing.assert_allclose(result.asnumpy(), res_np) def test_any_broadcast(): # Test broadcast with 1s @@ -77,6 +78,32 @@ def check_fail(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add) +def verify_any_full(x_shape, x_np_shape, relay_op, np_op, dtype='float32'): + x = relay.var('x', shape=x_shape, dtype=dtype) + mod = relay.module.Module() + mod['main'] = relay.Function([x], relay.zeros_like(x)) + x_np = np.random.uniform(size=x_np_shape).astype(dtype) + res_np = np.zeros_like(x_np) + for kind in ['debug', 'vm']: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm') + result = ex.evaluate()(x_np).asnumpy() + tvm.testing.assert_allclose(result, res_np) + +def test_any_full(): + # zeros, zeros_like, ones, ones_like + verify_any_full(any_dims(3), (2, 3, 5), relay.zeros, np.zeros, "float32") + verify_any_full(any_dims(3), (225, 115, 15), relay.zeros, np.zeros, "float32") + verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32") + verify_any_full(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32") + verify_any_full(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32") + verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32") + verify_any_full(any_dims(3), (2, 3, 5), relay.ones, np.ones, "float32") + verify_any_full(any_dims(3), (225, 115, 15), relay.ones, np.ones, "float32") + verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones, np.ones, "int32") + verify_any_full(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32") + verify_any_full(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32") + verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32") + def test_any_concat(): x = relay.var('x', shape=(relay.Any(), 2), dtype="float32") y = relay.var('y', shape=(1, 2), dtype="float32") @@ -85,10 +112,10 @@ def test_any_concat(): mod["main"] = relay.Function([x, y], z) x_np = np.random.uniform(size=(3, 2)).astype('float32') y_np = np.random.uniform(size=(1, 2)).astype('float32') + ref = np.concatenate([x_np, y_np], axis=0) for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()(x_np, y_np) - ref = np.concatenate([x_np, y_np], axis=0) tvm.testing.assert_allclose(result.asnumpy(), ref) def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape): @@ -116,10 +143,10 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): mod = relay.module.Module() mod["main"] = relay.Function([x], y) data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype) + expected = np.argwhere(data) for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()(data).asnumpy() - expected = np.argwhere(data) assert result.shape == expected.shape tvm.testing.assert_allclose(result.flatten(), expected.flatten()) @@ -412,10 +439,10 @@ def verify_any_pad(data_shape, pad_width, static_data_shape): y = relay.nn.pad(data, pad_width) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) + ref_out = np.pad(data_np, pad_width) for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()(data_np) - ref_out = np.pad(data_np, pad_width) tvm.testing.assert_allclose(result.asnumpy(), ref_out) def test_any_pad(): @@ -497,12 +524,12 @@ def _body(i, st): mod = relay.module.Module() mod["main"] = func data = np.array(0.0, dtype='int32') + ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32") # TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail # so currently we cannot run this test case on VM for kind in ["debug"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()(data) - ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32") np.testing.assert_allclose(result.asnumpy(), ref) def test_recursive_concat_with_wrong_annotation(): @@ -553,6 +580,7 @@ def _body(i, st): assert "in particular dimension 0 conflicts 2 does not match 1" in str(e) if __name__ == "__main__": + test_any_full() test_any_broadcast() test_any_broadcast_fail() test_any_concat()