From 40add301dd4b0e5271fa666cbee04c79d1da9ace Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 28 Nov 2018 23:42:00 +0900 Subject: [PATCH 01/10] Add support for tuple node in op fusion --- src/relay/pass/fuse_ops.cc | 45 +++++++++++++++--------- tests/python/relay/test_pass_fuse_ops.py | 39 ++++++++++++++++++-- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index b9e0823e88fa..9b41c82cec49 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -232,8 +232,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } void VisitExpr_(const TupleNode* op) { + CHECK(graph_.node_map.count(op)); + Node* tuple_node = graph_.node_map.at(op); + tuple_node->pattern = kInjective; for (const Expr& field : op->fields) { - this->Update(field, nullptr, kOpaque); + this->Update(field, tuple_node, kInjective); } ExprVisitor::VisitExpr_(op); this->AddNode(op); @@ -712,21 +715,8 @@ class FuseMutator : private ExprMutator { // then we must have a group assignment for it already. CHECK(gmap_.count(call)); auto* ret_group = gmap_.at(call)->FindRoot(); - Array new_args; - for (auto arg : call->args) { - auto type = arg->checked_type(); - CHECK(gmap_.count(arg.get())) - << "cannot find group of " << arg; - auto* arg_group = gmap_.at(arg.get())->FindRoot(); - Expr new_arg = this->Mutate(arg); - - if (ret_group != arg_group) { - Var param = ginfo_[ret_group].GetOrAllocParam(new_arg, type); - new_args.push_back(param); - } else { - new_args.push_back(new_arg); - } - } + Array new_args = GetNewArguments(call->args, ret_group); + auto new_call = CallNode::make( call->op, new_args, call->attrs, call->type_args); @@ -747,6 +737,29 @@ class FuseMutator : private ExprMutator { return ExprMutator::VisitExpr_(call); } } + + Expr VisitExpr_(const TupleNode* tuple) { + auto* ret_group = gmap_.at(tuple)->FindRoot(); + Array new_fields = GetNewArguments(tuple->fields, ret_group); + return TupleNode::make(new_fields); + } + + Array GetNewArguments(const tvm::Array& args, GraphPartitioner::Group* current_group) { + Array new_args; + for (auto arg : args) { + auto* arg_group = gmap_.at(arg.get())->FindRoot(); + auto type = arg->checked_type(); + Expr new_arg = this->Mutate(arg); + if (current_group != arg_group) { + Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type); + new_args.push_back(param); + } else { + new_args.push_back(new_arg); + } + } + return new_args; + } + // Debug function, dump the group assignment in text. void DebugDumpGroup(const Expr& body) { std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string { diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 27806791c399..f691bb88fa08 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -28,8 +28,6 @@ def expected(): assert relay.ir_pass.alpha_equal(zz, after) - - def test_conv2d_fuse(): """Test fusion case of conv2d""" def before(dshape): @@ -106,7 +104,44 @@ def expected(dshape): assert relay.ir_pass.alpha_equal(zz, after) +def test_concatenate(): + """Test fusion case involving concat op and Tuple node""" + + def before(dshape): + x = relay.var("x", shape=dshape) + pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW") + concat = relay.concatenate((upsampled, x), axis=1) + out = relay.add(concat, relay.const(1, "float32")) + return relay.Function(relay.ir_pass.free_vars(out), out) + + def expected(dshape): + x = relay.var("x", shape=dshape) + pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + f0 = relay.Function([x], pooled) + + p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) + p1 = relay.var("p1", shape=dshape) + upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW") + concat = relay.concatenate((upsampled, p1), axis=1) + out = relay.add(concat, relay.const(1, "float32")) + f1 = relay.Function([p0, p1], out) + + x = relay.var("x", shape=dshape) + y = relay.Call(f0, [x]) + z = relay.Call(f1, [y, x]) + return relay.Function([x], z) + + dshape = (1, 16, 64, 64) + z = before(dshape) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + zz = relay.ir_pass.infer_type(zz) + print(zz.astext()) + after = relay.ir_pass.infer_type(expected(dshape)) + assert relay.ir_pass.alpha_equal(zz, after) if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() + test_concatenate() From 3e58e06874f5b4d47783b9bee5db79c8070bea0c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 28 Nov 2018 23:49:20 +0900 Subject: [PATCH 02/10] remove print --- tests/python/relay/test_pass_fuse_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index f691bb88fa08..4e4355662eff 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -137,7 +137,6 @@ def expected(dshape): z = relay.ir_pass.infer_type(z) zz = relay.ir_pass.fuse_ops(z, opt_level=2) zz = relay.ir_pass.infer_type(zz) - print(zz.astext()) after = relay.ir_pass.infer_type(expected(dshape)) assert relay.ir_pass.alpha_equal(zz, after) From 88afe288c770d3cfe4b407133dcbdb2ff158593e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 29 Nov 2018 00:00:12 +0900 Subject: [PATCH 03/10] fix lint --- src/relay/pass/fuse_ops.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 9b41c82cec49..2974bf601a3e 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -744,7 +744,8 @@ class FuseMutator : private ExprMutator { return TupleNode::make(new_fields); } - Array GetNewArguments(const tvm::Array& args, GraphPartitioner::Group* current_group) { + Array GetNewArguments(const tvm::Array& args, + GraphPartitioner::Group* current_group) { Array new_args; for (auto arg : args) { auto* arg_group = gmap_.at(arg.get())->FindRoot(); From bb42eff20f86fa251df55cee4cbe858af4a74980 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 29 Nov 2018 12:52:26 +0900 Subject: [PATCH 04/10] handle opt level = 0 and tuple root --- src/relay/pass/fuse_ops.cc | 32 ++++++++++++++---- tests/python/relay/test_pass_fuse_ops.py | 43 ++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 2974bf601a3e..b01b69486455 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -659,9 +659,14 @@ class FuseMutator : private ExprMutator { auto graph = IndexedForwardGraph::Create(&arena_, body); auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition( graph); + for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { + // Make sure to init counts + node_count_[groups[nid]] = 0; + } for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { CHECK(graph.post_dfs_order[nid]->ref != nullptr); gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; + ++node_count_[groups[nid]->FindRoot()]; } // The following line can be used for debug. // this->DebugDumpGroup(body); @@ -698,6 +703,9 @@ class FuseMutator : private ExprMutator { std::unordered_map gmap_; /* \brief Internal group information map. */ std::unordered_map ginfo_; + /* \brief Counts the number of nodes in a group*/ + std::unordered_map node_count_; + // Skip primitive function. Expr VisitExpr_(const FunctionNode* fn_node) { NodeRef res = FunctionGetAttr(GetRef(fn_node), "Primitive"); @@ -723,11 +731,7 @@ class FuseMutator : private ExprMutator { if (ret_group->root_ref == call) { // This is the root of the group // create the new call node. - const GroupInfo& ginfo = ginfo_[ret_group]; - auto func = FunctionNode::make( - ginfo.params, new_call, call->checked_type(), {}); - func = FunctionSetAttr(func, "Primitive", tvm::Integer(1)); - return CallNode::make(func, ginfo.arguments, Attrs()); + return MakeNewFunction(ret_group, call->checked_type(), new_call); } else { // This is an intermediate node of a fused function // simply return the new call. @@ -741,7 +745,23 @@ class FuseMutator : private ExprMutator { Expr VisitExpr_(const TupleNode* tuple) { auto* ret_group = gmap_.at(tuple)->FindRoot(); Array new_fields = GetNewArguments(tuple->fields, ret_group); - return TupleNode::make(new_fields); + Tuple new_tuple = TupleNode::make(new_fields); + if (ret_group == gmap_.at(tuple)) { + if (node_count_[ret_group] == 1) { + // Do not put a isolated tuple into a function + return ExprMutator::VisitExpr_(tuple); + } + // This tuple has been fused other ops before it + return MakeNewFunction(ret_group, TupleType(), new_tuple); + } + return new_tuple; + } + + Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { + const GroupInfo& ginfo = ginfo_[group]; + auto func = FunctionNode::make(ginfo.params, body, ret_type, {}); + func = FunctionSetAttr(func, "Primitive", tvm::Integer(1)); + return CallNode::make(func, ginfo.arguments, Attrs()); } Array GetNewArguments(const tvm::Array& args, diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 4e4355662eff..823c44b9676d 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -135,12 +135,55 @@ def expected(dshape): dshape = (1, 16, 64, 64) z = before(dshape) z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) zz = relay.ir_pass.fuse_ops(z, opt_level=2) zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) after = relay.ir_pass.infer_type(expected(dshape)) assert relay.ir_pass.alpha_equal(zz, after) + +def test_tuple_root(): + """Test fusion case involving concat op and Tuple node""" + + def before(dshape): + x = relay.var("x", shape=dshape) + pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW") + out = relay.Tuple((upsampled, x)) + return relay.Function(relay.ir_pass.free_vars(out), out) + + def expected(dshape): + x = relay.var("x", shape=dshape) + pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + f0 = relay.Function([x], pooled) + + p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) + p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3])) + upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW") + out = relay.Tuple((upsampled, p1)) + f1 = relay.Function([p0, p1], out) + + x = relay.var("x", shape=dshape) + y = relay.Call(f0, [x]) + z = relay.Call(f1, [y, x]) + return relay.Function([x], z) + + dshape = (1, 16, 64, 64) + z = before(dshape) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) + after = relay.ir_pass.infer_type(expected(dshape)) + assert relay.ir_pass.alpha_equal(zz, after) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() test_concatenate() + test_tuple_root() From f80398f9e6a119040819fc76e152387064c978a5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 29 Nov 2018 12:58:35 +0900 Subject: [PATCH 05/10] fix lint --- src/relay/pass/fuse_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index b01b69486455..f493b44e903d 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -754,7 +754,7 @@ class FuseMutator : private ExprMutator { // This tuple has been fused other ops before it return MakeNewFunction(ret_group, TupleType(), new_tuple); } - return new_tuple; + return new_tuple; } Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { From b48469aebc7ec8bfad1c67bfd61a7732de7605be Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 29 Nov 2018 13:08:59 +0900 Subject: [PATCH 06/10] fix tuple function return type --- src/relay/pass/fuse_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index f493b44e903d..01344581c383 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -752,7 +752,7 @@ class FuseMutator : private ExprMutator { return ExprMutator::VisitExpr_(tuple); } // This tuple has been fused other ops before it - return MakeNewFunction(ret_group, TupleType(), new_tuple); + return MakeNewFunction(ret_group, tuple->checked_type(), new_tuple); } return new_tuple; } From 7d017c5374c57b5c952165c255efd1c14ae44fb7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 29 Nov 2018 15:46:19 +0900 Subject: [PATCH 07/10] remove node count --- src/relay/pass/fuse_ops.cc | 17 +++++++---------- tests/python/relay/test_pass_fuse_ops.py | 2 +- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 01344581c383..595912381afb 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -659,14 +659,9 @@ class FuseMutator : private ExprMutator { auto graph = IndexedForwardGraph::Create(&arena_, body); auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition( graph); - for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { - // Make sure to init counts - node_count_[groups[nid]] = 0; - } for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { CHECK(graph.post_dfs_order[nid]->ref != nullptr); gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; - ++node_count_[groups[nid]->FindRoot()]; } // The following line can be used for debug. // this->DebugDumpGroup(body); @@ -703,9 +698,6 @@ class FuseMutator : private ExprMutator { std::unordered_map gmap_; /* \brief Internal group information map. */ std::unordered_map ginfo_; - /* \brief Counts the number of nodes in a group*/ - std::unordered_map node_count_; - // Skip primitive function. Expr VisitExpr_(const FunctionNode* fn_node) { NodeRef res = FunctionGetAttr(GetRef(fn_node), "Primitive"); @@ -747,13 +739,18 @@ class FuseMutator : private ExprMutator { Array new_fields = GetNewArguments(tuple->fields, ret_group); Tuple new_tuple = TupleNode::make(new_fields); if (ret_group == gmap_.at(tuple)) { - if (node_count_[ret_group] == 1) { + bool isolated = true; + for (int i = 0; i < new_fields.size(); ++i) { + isolated &= (new_fields[i] == ginfo_[ret_group].params[i]); + } + if (isolated) { // Do not put a isolated tuple into a function return ExprMutator::VisitExpr_(tuple); } - // This tuple has been fused other ops before it + // This tuple has been fused with other ops before it return MakeNewFunction(ret_group, tuple->checked_type(), new_tuple); } + // This tuple is an intermediate node in the group return new_tuple; } diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 823c44b9676d..28ea8dd28988 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -145,7 +145,7 @@ def expected(dshape): def test_tuple_root(): - """Test fusion case involving concat op and Tuple node""" + """Test fusion case where Tuple node is the root in its group""" def before(dshape): x = relay.var("x", shape=dshape) From 6b3ea18a6f868002b141e3940edde05003cd6a2f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 29 Nov 2018 15:49:25 +0900 Subject: [PATCH 08/10] fix lint --- src/relay/pass/fuse_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 595912381afb..e5490b3777ee 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -740,7 +740,7 @@ class FuseMutator : private ExprMutator { Tuple new_tuple = TupleNode::make(new_fields); if (ret_group == gmap_.at(tuple)) { bool isolated = true; - for (int i = 0; i < new_fields.size(); ++i) { + for (int i = 0; i < new_fields.size(); ++i) { isolated &= (new_fields[i] == ginfo_[ret_group].params[i]); } if (isolated) { From baf41568c415dce6e779eb40bb0b6fe605971028 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 29 Nov 2018 15:51:48 +0900 Subject: [PATCH 09/10] use size_t --- src/relay/pass/fuse_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index e5490b3777ee..9daa63145620 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -740,7 +740,7 @@ class FuseMutator : private ExprMutator { Tuple new_tuple = TupleNode::make(new_fields); if (ret_group == gmap_.at(tuple)) { bool isolated = true; - for (int i = 0; i < new_fields.size(); ++i) { + for (size_t i = 0; i < new_fields.size(); ++i) { isolated &= (new_fields[i] == ginfo_[ret_group].params[i]); } if (isolated) { From 9d0925d6df80e3858104ef9208c0d9a4b02b0962 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Nov 2018 07:58:57 +0900 Subject: [PATCH 10/10] use same_as --- src/relay/pass/fuse_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 9daa63145620..21660decf2fa 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -741,7 +741,7 @@ class FuseMutator : private ExprMutator { if (ret_group == gmap_.at(tuple)) { bool isolated = true; for (size_t i = 0; i < new_fields.size(); ++i) { - isolated &= (new_fields[i] == ginfo_[ret_group].params[i]); + isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i])); } if (isolated) { // Do not put a isolated tuple into a function