Fix loop var naming issue
kevinthesun committed Jun 5, 2020
1 parent 39964b9 commit 3428fcb
Showing 2 changed files with 47 additions and 6 deletions.
51 changes: 46 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
@@ -34,7 +34,7 @@
from .. import function as _function
from .. import op as _op
from ..ty import Any
from ..expr_functor import ExprVisitor
from ..expr_functor import ExprMutator, ExprVisitor
from .common import AttrCvt, get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
@@ -2416,6 +2416,26 @@ def _in_while_loop(control_flow_node_map, op_name):
return op_name in control_flow_node_map and \
"LoopCond" in control_flow_node_map[op_name]

class RewriteSubgraph(ExprMutator):
"""
    A helper class to rewrite expressions in a while loop function into variables.
Parameters
----------
rewrite_map : Dict[expr, expr]
        A dictionary mapping expressions to the variables that replace them.
"""
def __init__(self, rewrite_map):
ExprMutator.__init__(self)
self.rewrite_map = rewrite_map

def visit(self, expr):
if expr in self.rewrite_map:
return self.rewrite_map[expr]
return super().visit(expr)

def rewrite_subgraph(expr, rewrites):
return RewriteSubgraph(rewrites).visit(expr)

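In isolation the helper is a plain substitution pass: any expression that appears as a key in the map is replaced by the mapped variable wherever it occurs. A minimal sketch (hypothetical names and shapes, assuming tvm.relay is importable and this file's rewrite_subgraph is in scope) of how the while-loop conversion below uses it to swap an un-lifted loop value for a fresh typed variable:

    from tvm import relay

    # An expression that should have been a loop variable but was never lifted.
    raw_value = relay.add(relay.const(1, "int32"), relay.const(2, "int32"))
    cond_expr = relay.less(raw_value, relay.const(10, "int32"))

    # Fresh variable following the "<loop_name>_loop_var_<i>" scheme used below.
    loop_var = relay.var("while_loop_var_0", shape=(), dtype="int32")

    # Every occurrence of raw_value inside cond_expr is rewritten to loop_var.
    cond_expr = rewrite_subgraph(cond_expr, {raw_value: loop_var})
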
class Branch:
"""A class contains the components that are used to build up a Relay if
node.
@@ -2576,18 +2596,39 @@ def _while_loop(self):
"""An internal API to create a Relay recursive call for a matched TF
`while_loop` construct.
"""
bind_map = {}
wl = tvm.relay.var('while_loop')
sb = tvm.relay.scope_builder.ScopeBuilder()
cond = tvm.relay.op.min(self.cond)

lv_list = []
expr_list = []
extra_vars = []

for lv in self.loop_vars:
for i, lv in enumerate(self.loop_vars):
if self._loop_name not in self._lvar2expr:
self._lvar2expr[self._loop_name] = {}

            # Handle the case when the loop var was not properly lifted.
            # This can happen when the loop var's node name accidentally
            # begins with the loop name.
if lv not in self._lvar2expr[self._loop_name]:
var_name = "{}_loop_var_{}".format(self._loop_name, i)
var_type = _infer_type(lv, self._mod).checked_type
loop_var = tvm.relay.var(var_name, type_annotation=var_type)
self._lvar2expr[self._loop_name][loop_var] = lv
bind_map[lv] = loop_var
self.loop_vars[i] = loop_var
lv = loop_var

lv_list.append(lv)
expr_list.append(self._lvar2expr[self._loop_name][lv])

if bind_map:
self.cond = rewrite_subgraph(self.cond, bind_map)
self.body = [rewrite_subgraph(b, bind_map) for b in self.body]

cond = tvm.relay.op.min(self.cond)

for lv, exp in self._lvar2expr[self._loop_name].items():
if lv not in self.loop_vars:
var_checker = VarChecker(lv)
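
For reference, the construct this method ultimately emits is an ordinary Relay recursive function bound with a Let: the condition guards a recursive call whose arguments are the updated loop values, and the else branch returns the final state. A simplified, self-contained sketch of that shape for a single int32 counter (not the frontend code itself, just the pattern it generates):

    from tvm import relay
    from tvm.relay import scope_builder

    i = relay.var("i", shape=(), dtype="int32")
    wl = relay.var("while_loop")
    sb = scope_builder.ScopeBuilder()
    with sb.if_scope(relay.less(i, relay.const(10, "int32"))):
        # Recurse with the updated loop variable.
        sb.ret(wl(relay.add(i, relay.const(1, "int32"))))
    with sb.else_scope():
        sb.ret(i)
    loop_fn = relay.Function([i], sb.get())
    # Bind the recursive function and invoke it with the initial value.
    loop = relay.Let(wl, loop_fn, wl(relay.const(0, "int32")))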
@@ -3211,8 +3252,8 @@ def _licm_construct(self, loop_name, node_name):
Converted relay expression or loop var.
"""
actual_expr = self._backtrack_construct(node_name)
tn = node_name.split(':').split("^")[-1]
node_name = tn[0]
tn = node_name.split(':')
node_name = tn[0].split("^")[-1]
cloop_name = find_parent_loop_name(node_name, self._while_loop_name_set)

if loop_name in self._while_loop_name_set and not cloop_name.startswith(loop_name):
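
The reordered split strips the output-port suffix before the control-dependency marker rather than the other way around. A quick illustration, with made-up node names, of what the two fixed lines compute:

    for raw in ["while/Less/y:0", "^while/Identity"]:
        tn = raw.split(':')
        name = tn[0].split("^")[-1]
        print(raw, "->", name)
    # while/Less/y:0  -> while/Less/y
    # ^while/Identity -> while/Identity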
2 changes: 1 addition & 1 deletion tests/python/frontend/tensorflow/test_control_flow.py
@@ -46,7 +46,7 @@ def check_equal(graph, tf_out, input_map=None):
def test_vanilla_loop():
graph = tf.Graph()
with graph.as_default():
i = tf.constant(0)
i = tf.constant(0, name="while/constant")

def c(i): return tf.less(i, 10)

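Naming the constant so that its node name starts with the loop frame name "while" is what steers the converter into the new fallback path. A minimal standalone sketch of such a loop (the body function and the tf.while_loop call are assumed here, following the usual pattern of these tests):

    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default():
        # Node name deliberately collides with the loop frame name.
        i = tf.constant(0, name="while/constant")

        def c(i): return tf.less(i, 10)

        def b(i): return tf.add(i, 1)

        r = tf.while_loop(c, b, [i])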
