Skip to content

Commit

Permalink
[MergeComposite] Fix InferType when module contains Prelude (#5797)
Browse files Browse the repository at this point in the history
A function may refer to other resources in the same module, so keep
  the content of original module when infering a function.
  • Loading branch information
lixiaoquan committed Jun 16, 2020
1 parent 8931cfa commit 6ed8d7a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
5 changes: 4 additions & 1 deletion python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm import te
import tvm.relay as relay
import tvm.relay.op as op
from tvm.relay import Prelude


from . import mlp
Expand All @@ -44,9 +45,11 @@
from .py_converter import to_python, run_as_python
from ..transform import gradient

def run_opt_pass(expr, opt_pass):
def run_opt_pass(expr, opt_pass, import_prelude=False):
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
if import_prelude:
Prelude(mod)
mod = opt_pass(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
Expand Down
12 changes: 7 additions & 5 deletions src/relay/transforms/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,24 @@ namespace tvm {
namespace relay {
namespace merge_composite {

Function InferType(const Function& expr) {
auto mod = IRModule::FromExpr(expr);
Function InferType(const Function& expr, const IRModule& m) {
IRModule mod(m);
mod->Update(mod->GetGlobalVar("main"), expr);
mod = transform::InferType()(mod);
return Downcast<Function>(mod->Lookup("main"));
}

Expr MergeComposite(const Function& func, const Array<runtime::String>& pattern_names,
const Array<DFPattern>& patterns, const std::vector<PackedFunc>& checks) {
const Array<DFPattern>& patterns, const std::vector<PackedFunc>& checks,
const IRModule& m) {
CHECK_EQ(pattern_names.size(), patterns.size());
Function merged_func = func;
// merge the patterns one-by-one in order
for (size_t i = 0; i < patterns.size(); i++) {
Map<String, ObjectRef> attrs;
attrs.Set("Composite", pattern_names[i]);
merged_func = Downcast<Function>(PartitionPattern(patterns[i], merged_func, attrs, checks[i]));
merged_func = InferType(merged_func);
merged_func = InferType(merged_func, m);
}
return std::move(merged_func);
}
Expand All @@ -65,7 +67,7 @@ Pass MergeComposite(const tvm::Array<runtime::String>& pattern_names,
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(
relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks));
relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks, m));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
return func_pass;
Expand Down
6 changes: 3 additions & 3 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def make_bn_relu_pattern():
r = is_op('nn.relu')(tuple_get_item_node)
return r

def check_result(pattern_table, graph, expected_graph):
def check_result(pattern_table, graph, expected_graph, import_prelude=False):
"""Utility function to check merge composite results."""
result = run_opt_pass(graph, relay.transform.MergeComposite(pattern_table))
result = run_opt_pass(graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude)
assert not relay.analysis.free_vars(result), \
"Found free vars in the result graph: {0}".format(str(result))
expected = run_opt_pass(expected_graph, relay.transform.InferType())
Expand Down Expand Up @@ -213,7 +213,7 @@ def expected():
r = relay.Call(add_relu, [a, b])
return relay.Function([a, b], r)

check_result(pattern_table, before(), expected())
check_result(pattern_table, before(), expected(), import_prelude=True)


def test_branch_merge():
Expand Down

0 comments on commit 6ed8d7a

Please sign in to comment.