diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 52a501c4cd33a..e24765529542f 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -34,7 +34,7 @@ from .. import function as _function from .. import op as _op from ..ty import Any -from ..expr_functor import ExprVisitor +from ..expr_functor import ExprMutator, ExprVisitor from .common import AttrCvt, get_relay_op from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape @@ -2416,6 +2416,26 @@ def _in_while_loop(control_flow_node_map, op_name): return op_name in control_flow_node_map and \ "LoopCond" in control_flow_node_map[op_name] +class RewriteSubgraph(ExprMutator): + """ + A helper class to rewrite expr in while loop function to variable + Parameters + ---------- + rewrite_map : Dict[expr, expr] + A dictionay contains a set of expr to var mapping. + """ + def __init__(self, rewrite_map): + ExprMutator.__init__(self) + self.rewrite_map = rewrite_map + + def visit(self, expr): + if expr in self.rewrite_map: + return self.rewrite_map[expr] + return super().visit(expr) + +def rewrite_subgraph(expr, rewrites): + return RewriteSubgraph(rewrites).visit(expr) + class Branch: """A class contains the components that are used to build up a Relay if node. @@ -2576,18 +2596,39 @@ def _while_loop(self): """An internal API to create a Relay recursive call for a matched TF `while_loop` construct. """ + bind_map = {} wl = tvm.relay.var('while_loop') sb = tvm.relay.scope_builder.ScopeBuilder() - cond = tvm.relay.op.min(self.cond) lv_list = [] expr_list = [] extra_vars = [] - for lv in self.loop_vars: + for i, lv in enumerate(self.loop_vars): + if self._loop_name not in self._lvar2expr: + self._lvar2expr[self._loop_name] = {} + + # Handle the case when loop var is not properly lifted. + # This can happen when loop var node name is set accidentally + # beginning with loop name. + if lv not in self._lvar2expr[self._loop_name]: + var_name = "{}_loop_var_{}".format(self._loop_name, i) + var_type = _infer_type(lv, self._mod).checked_type + loop_var = tvm.relay.var(var_name, type_annotation=var_type) + self._lvar2expr[self._loop_name][loop_var] = lv + bind_map[lv] = loop_var + self.loop_vars[i] = loop_var + lv = loop_var + lv_list.append(lv) expr_list.append(self._lvar2expr[self._loop_name][lv]) + if bind_map: + self.cond = rewrite_subgraph(self.cond, bind_map) + self.body = [rewrite_subgraph(b, bind_map) for b in self.body] + + cond = tvm.relay.op.min(self.cond) + for lv, exp in self._lvar2expr[self._loop_name].items(): if lv not in self.loop_vars: var_checker = VarChecker(lv) @@ -3211,8 +3252,8 @@ def _licm_construct(self, loop_name, node_name): Converted relay expression or loop var. """ actual_expr = self._backtrack_construct(node_name) - tn = node_name.split(':').split("^")[-1] - node_name = tn[0] + tn = node_name.split(':') + node_name = tn[0].split("^")[-1] cloop_name = find_parent_loop_name(node_name, self._while_loop_name_set) if loop_name in self._while_loop_name_set and not cloop_name.startswith(loop_name): diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index f00018de299f9..3ec04bf384909 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -46,7 +46,7 @@ def check_equal(graph, tf_out, input_map=None): def test_vanilla_loop(): graph = tf.Graph() with graph.as_default(): - i = tf.constant(0) + i = tf.constant(0, name="while/constant") def c(i): return tf.less(i, 10)