diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index afeaee7e8f95..5ac0de4335f7 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -543,25 +543,23 @@ def _impl(inputs, attr, params):
                 op_name="reshape",
                 extras={'newshape':tuple(shape_arg.asnumpy())},
                 ignores=['Tshape'])(inputs, attr)
-        except KeyError:
+        except AttributeError:
             # Shape operator is already pruned, hence
             # try to infer shape by precompute prune if possible.
-            if all(in_node in params for in_node in inputs[1].list_input_names()):
-                func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
-                with tvm.relay.build_config(opt_level=0):
-                    graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-                ctx = tvm.context("llvm", 0)
-                from tvm.contrib import graph_runtime
-                m = graph_runtime.create(graph, lib, ctx)
-                m.set_input(**params)
-                m.run()
-                params_new = m.get_output(0)
-                inputs.pop(1)
-                return AttrCvt(
-                    op_name="reshape",
-                    extras={'newshape':tuple(params_new.asnumpy().flatten())},
-                    ignores=['Tshape'])(inputs, attr)
-            raise RuntimeError("Reshape with dynamic shape input not supported yet.")
+            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
+            with tvm.relay.build_config(opt_level=0):
+                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+            ctx = tvm.context("llvm", 0)
+            from tvm.contrib import graph_runtime
+            m = graph_runtime.create(graph, lib, ctx)
+            m.set_input(**params)
+            m.run()
+            params_new = m.get_output(0)
+            inputs.pop(1)
+            return AttrCvt(
+                op_name="reshape",
+                extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())},
+                ignores=['Tshape'])(inputs, attr)
     return _impl

 def _bias_add():
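For context, a minimal sketch (not part of the patch) of a graph that should exercise the patched fallback. When the Reshape's target shape comes from a computed chain (e.g. Shape -> StridedSlice -> Pack) rather than a Const, `inputs[1]` is not a named variable, so `params.pop(inputs[1].name_hint)` raises `AttributeError` rather than `KeyError`, and the importer precomputes the shape with `graph_runtime` as above. The TensorFlow 1.x graph below and the `(expr, params)` return signature of `from_tensorflow` are assumptions about this era of the codebase, and the tensor names are illustrative:

```python
# Illustrative only: a TF 1.x graph whose Reshape shape operand is computed
# (Shape -> StridedSlice -> Pack) rather than a Const, so the importer's
# params.pop(inputs[1].name_hint) raises AttributeError and falls into the
# precompute branch patched above.
import tensorflow as tf
from tvm import relay

with tf.Graph().as_default() as graph:
    data = tf.placeholder(tf.float32, shape=(1, 3, 4), name="data")
    dyn_dim = tf.shape(data)[0]            # dynamic leading dimension
    out = tf.reshape(data, [dyn_dim, -1])  # shape operand is a Pack op, not a Const
    graph_def = graph.as_graph_def()

# Signature and (expr, params) return value assumed for this era of the
# relay frontend; later versions return a module instead.
expr, params = relay.frontend.from_tensorflow(graph_def,
                                              shape={"data": (1, 3, 4)})
```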