From 38dee41906962b8f2c1eb86e85df134e967d7fad Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Wed, 28 Oct 2020 15:32:35 -0500 Subject: [PATCH] [ManifestAlloc] Handle TupleType inputs in CheckReshapeOnly (#6776) * Changes in CheckReshapeOnly to support TupleTypes as input This arises insed ManifestAllocPass inside relay.vm.compile * [ManifestAlloc] Handle TupleType inputs in CheckReshapeOnly --- python/tvm/relay/transform/memory_alloc.py | 5 +++++ tests/python/relay/test_vm.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index f611c1cc14c1..66528c861788 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -84,6 +84,11 @@ def visit_call(self, call): for arg in call.args: self.visit(arg) + def visit_var(self, var): + var_type = var.checked_type + if not isinstance(var_type, ty.TensorType): + self.reshape_only = False + def is_reshape_only(func): """Check if the primitive function contains only reshape ops.""" diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 038b5c5ed9e1..92d6e8e55db4 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -754,5 +754,21 @@ def test_vm_reshape_tensor(): check_result([x_np, y_np], x_np.reshape([8, 2, 8]), mod) +def test_vm_reshape_tuple(x_shape=(1, 4, 2), y_shape=(1, 2, 10)): + tup = relay.var( + "tup", + type_annotation=relay.TupleType([relay.TensorType(x_shape), relay.TensorType(y_shape)]), + ) + out = relay.reshape(relay.TupleGetItem(tup, 0), (1, -1)) + f = relay.Function([tup], out) + + x_data = np.random.uniform(size=x_shape).astype("float32") + y_data = np.random.uniform(size=y_shape).astype("float32") + + for tgt, ctx in tvm.testing.enabled_targets(): + res = veval(f, (x_data, y_data), ctx=ctx, target=tgt) + tvm.testing.assert_allclose(res.asnumpy(), np.reshape(x_data, (1, -1))) + + if __name__ == "__main__": pytest.main([__file__])