Skip to content

Commit

Permalink
fix fuse over functions that are handled by external codegen (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and dhruvaray committed Apr 28, 2020
1 parent 7f995ce commit 29da9ec
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
14 changes: 7 additions & 7 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -924,20 +924,20 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());

// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());

// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
// and we use these ops to invoke the symbols in the module generated by
// external codegen.
pass_seqs.push_back(transform::Inline());

// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());

// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));

Expand Down
3 changes: 3 additions & 0 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ class IndexedForwardGraph::Creator : private ExprVisitor {

// Post order tree
void VisitExpr_(const FunctionNode* op) final {
// Skip the function that should be handled by external codegen.
if (op->GetAttr<String>(attr::kCompiler).defined()) return;

for (auto param : op->params) {
this->Update(param, nullptr, kOpaque);
}
Expand Down
1 change: 0 additions & 1 deletion tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,6 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32')

mod["main"] = bind_params_by_name(mod["main"], params)
mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod)
Expand Down

0 comments on commit 29da9ec

Please sign in to comment.