Skip to content

Commit

Permalink
Merge pull request apache#5 from cmu-catalyst/dp-fused-pass
Browse files Browse the repository at this point in the history
[DP fusion pass] Fix the issue of not assigning backend op to kOpaque…
  • Loading branch information
MadFunMaker authored Apr 19, 2021
2 parents ac83ecb + 5bf9c41 commit 7e5da7d
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ namespace tvm {
// graph_.exprnode_to_backend_op[key].debug_print();
it++;
}

}

// void VisitExpr(const Expr& expr) {
Expand Down Expand Up @@ -818,19 +817,19 @@ namespace tvm {

// std::cerr << "Group node (" << nid << ") pattern: " << group_node->pattern << std::endl;

// Note that Var or Constant will be filtered out by this.
if (group_node->pattern == kOpaque) continue;

// WARNING(@Soo): We assume that fused ops are always not opaque.
// no actions for opaque nodes
// Assign backend op name
// WARNING(@Soo): We should assume that fused ops are not always opaque.
const tvm::Object* cur_key = graph_node->ref;
assert (graph.exprnode_to_backend_op.find(cur_key) != graph.exprnode_to_backend_op.end());
GroupIdOpNamePair pair_info = graph.exprnode_to_backend_op.at(cur_key);
int cur_group_id = pair_info.group_id;

group_node->backend_op_name = pair_info.backend_op_name;

// Note that Var or Constant will be filtered out by this.
// Softmax is also kOpaque
if (group_node->pattern == kOpaque) continue;

// Get group id for cur and prev node
int cur_group_id = pair_info.group_id;
auto* prev_graph_node = graph.post_dfs_order[prev_nid];
if (cur_group_id == prev_group_id) {
// std::cerr << "cur, pre nid: " << nid << " / " << prev_nid << std::endl;
Expand Down Expand Up @@ -1114,6 +1113,8 @@ namespace tvm {

// PATCH(@Soo): Add backend op attribute.
func = WithAttr(std::move(func), attr::kBackendOp, String(group->backend_op_name));
// std::cerr << "Func: " << func << std::endl;
// std::cerr << "Backend op name: " << group->backend_op_name << std::endl;
return Call(func, ginfo.arguments, Attrs());
}

Expand Down

0 comments on commit 7e5da7d

Please sign in to comment.