From 6ed8d7a8c2a9d1c3ca104400f152ab8e97556705 Mon Sep 17 00:00:00 2001 From: lixiaoquan Date: Wed, 17 Jun 2020 06:11:22 +0800 Subject: [PATCH] [MergeComposite] Fix InferType when module contains Prelude (#5797) A function may refer to other resources in the same module, so keep the content of original module when infering a function. --- python/tvm/relay/testing/__init__.py | 5 ++++- src/relay/transforms/merge_composite.cc | 12 +++++++----- tests/python/relay/test_pass_merge_composite.py | 6 +++--- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index b8ef906b81ad..8310a0202c17 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -23,6 +23,7 @@ from tvm import te import tvm.relay as relay import tvm.relay.op as op +from tvm.relay import Prelude from . import mlp @@ -44,9 +45,11 @@ from .py_converter import to_python, run_as_python from ..transform import gradient -def run_opt_pass(expr, opt_pass): +def run_opt_pass(expr, opt_pass, import_prelude=False): assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) + if import_prelude: + Prelude(mod) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 324b2cb3a1c4..7e7ad0e665a7 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -36,14 +36,16 @@ namespace tvm { namespace relay { namespace merge_composite { -Function InferType(const Function& expr) { - auto mod = IRModule::FromExpr(expr); +Function InferType(const Function& expr, const IRModule& m) { + IRModule mod(m); + mod->Update(mod->GetGlobalVar("main"), expr); mod = transform::InferType()(mod); return Downcast(mod->Lookup("main")); } Expr MergeComposite(const Function& func, const Array& pattern_names, - const Array& patterns, const std::vector& checks) { + const Array& patterns, const std::vector& checks, + const IRModule& m) { CHECK_EQ(pattern_names.size(), patterns.size()); Function merged_func = func; // merge the patterns one-by-one in order @@ -51,7 +53,7 @@ Expr MergeComposite(const Function& func, const Array& pattern_ Map attrs; attrs.Set("Composite", pattern_names[i]); merged_func = Downcast(PartitionPattern(patterns[i], merged_func, attrs, checks[i])); - merged_func = InferType(merged_func); + merged_func = InferType(merged_func, m); } return std::move(merged_func); } @@ -65,7 +67,7 @@ Pass MergeComposite(const tvm::Array& pattern_names, runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast( - relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks)); + relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks, m)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {}); return func_pass; diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index f2d615e9046a..ddb5b5dab675 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -163,9 +163,9 @@ def make_bn_relu_pattern(): r = is_op('nn.relu')(tuple_get_item_node) return r -def check_result(pattern_table, graph, expected_graph): +def check_result(pattern_table, graph, expected_graph, import_prelude=False): """Utility function to check merge composite results.""" - result = run_opt_pass(graph, relay.transform.MergeComposite(pattern_table)) + result = run_opt_pass(graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude) assert not relay.analysis.free_vars(result), \ "Found free vars in the result graph: {0}".format(str(result)) expected = run_opt_pass(expected_graph, relay.transform.InferType()) @@ -213,7 +213,7 @@ def expected(): r = relay.Call(add_relu, [a, b]) return relay.Function([a, b], r) - check_result(pattern_table, before(), expected()) + check_result(pattern_table, before(), expected(), import_prelude=True) def test_branch_merge():