[Relay] Add support for tuple node in operator fusion (apache#2187)
masahi authored and Wei Chen committed Feb 20, 2019
1 parent f7bd2b1 commit 10d9534
Showing 2 changed files with 131 additions and 23 deletions.
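Before this change, the fusion pass treated TupleNode fields as opaque edges, so fusion stopped at every tuple (e.g. the packed inputs of concatenate). A minimal sketch of the behavior the commit enables — illustration only, not part of the diff, written against the same relay.ir_pass API the new tests below exercise:

from tvm import relay

x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.add(x, relay.const(1, "float32"))
out = relay.Tuple((y, x))  # tuple consuming an injective result
f = relay.Function(relay.ir_pass.free_vars(out), out)
f = relay.ir_pass.infer_type(f)
# With this patch, fuse_ops can place the Tuple in the same fused
# ("Primitive") function as the add feeding it, instead of leaving
# the tuple outside every group.
fused = relay.ir_pass.fuse_ops(f, opt_level=2)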
73 changes: 52 additions & 21 deletions src/relay/pass/fuse_ops.cc
@@ -232,8 +232,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}

void VisitExpr_(const TupleNode* op) {
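// The tuple itself is injective, and each field now points back at the
// tuple node rather than along an opaque edge, so the tuple can be
// grouped with the producers of its fields.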
CHECK(graph_.node_map.count(op));
Node* tuple_node = graph_.node_map.at(op);
tuple_node->pattern = kInjective;
for (const Expr& field : op->fields) {
this->Update(field, nullptr, kOpaque);
this->Update(field, tuple_node, kInjective);
}
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
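The Update calls above rely on the fuser's operator-pattern lattice. A condensed sketch of OpPatternKind as defined in TVM around this revision (the exact values and comments here are an assumption about the header, not a quote):

enum OpPatternKind {
  kElemWise = 0,         // elementwise: out[i] = f(in[i])
  kBroadcast = 1,        // broadcasting elementwise op
  kInjective = 2,        // injective axis mapping (tuples now use this)
  kCommReduce = 3,       // commutative reduction
  kOutEWiseFusable = 4,  // complex op (e.g. conv2d) that can fuse elementwise consumers
  kOpaque = 8            // fuses with nothing
};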
@@ -712,32 +715,15 @@ class FuseMutator : private ExprMutator {
// then we must have a group assignment for it already.
CHECK(gmap_.count(call));
auto* ret_group = gmap_.at(call)->FindRoot();
Array<Expr> new_args;
for (auto arg : call->args) {
auto type = arg->checked_type();
CHECK(gmap_.count(arg.get()))
<< "cannot find group of " << arg;
auto* arg_group = gmap_.at(arg.get())->FindRoot();
Expr new_arg = this->Mutate(arg);

if (ret_group != arg_group) {
Var param = ginfo_[ret_group].GetOrAllocParam(new_arg, type);
new_args.push_back(param);
} else {
new_args.push_back(new_arg);
}
}
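// Per-argument rewriting is factored into GetNewArguments() below so the
// new TupleNode visitor can share it.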
Array<Expr> new_args = GetNewArguments(call->args, ret_group);

auto new_call = CallNode::make(
call->op, new_args, call->attrs, call->type_args);

if (ret_group->root_ref == call) {
// This is the root of the group
// create the new call node.
const GroupInfo& ginfo = ginfo_[ret_group];
auto func = FunctionNode::make(
ginfo.params, new_call, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
return CallNode::make(func, ginfo.arguments, Attrs());
return MakeNewFunction(ret_group, call->checked_type(), new_call);
} else {
// This is an intermediate node of a fused function
// simply return the new call.
@@ -747,6 +733,51 @@ class FuseMutator : private ExprMutator {
return ExprMutator::VisitExpr_(call);
}
}

Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
Tuple new_tuple = TupleNode::make(new_fields);
if (ret_group == gmap_.at(tuple)) {
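// The tuple is the root of its own group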
bool isolated = true;
for (size_t i = 0; i < new_fields.size(); ++i) {
isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
}
if (isolated) {
// Do not put an isolated tuple into a function
return ExprMutator::VisitExpr_(tuple);
}
// This tuple has been fused with other ops before it
return MakeNewFunction(ret_group, tuple->checked_type(), new_tuple);
}
// This tuple is an intermediate node in the group
return new_tuple;
}

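// Wrap the fused body in a function over the group's parameters, mark it
// Primitive, and call it with the arguments collected for the group.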
Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
return CallNode::make(func, ginfo.arguments, Attrs());
}

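// Rewrite arguments; any input that comes from outside current_group is
// replaced by a fresh parameter of the fused function.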
Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
GraphPartitioner::Group* current_group) {
Array<Expr> new_args;
for (auto arg : args) {
auto* arg_group = gmap_.at(arg.get())->FindRoot();
auto type = arg->checked_type();
Expr new_arg = this->Mutate(arg);
if (current_group != arg_group) {
Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
new_args.push_back(param);
} else {
new_args.push_back(new_arg);
}
}
return new_args;
}

// Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) {
std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string {
81 changes: 79 additions & 2 deletions tests/python/relay/test_pass_fuse_ops.py
@@ -28,8 +28,6 @@ def expected():
assert relay.ir_pass.alpha_equal(zz, after)




def test_conv2d_fuse():
"""Test fusion case of conv2d"""
def before(dshape):
@@ -106,7 +104,86 @@ def expected(dshape):
assert relay.ir_pass.alpha_equal(zz, after)


def test_concatenate():
"""Test fusion case involving concat op and Tuple node"""

def before(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
concat = relay.concatenate((upsampled, x), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
return relay.Function(relay.ir_pass.free_vars(out), out)

def expected(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled)

p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=dshape)
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
concat = relay.concatenate((upsampled, p1), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
f1 = relay.Function([p0, p1], out)

x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y, x])
return relay.Function([x], z)

dshape = (1, 16, 64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)


def test_tuple_root():
"""Test fusion case where Tuple node is the root in its group"""

def before(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, x))
return relay.Function(relay.ir_pass.free_vars(out), out)

def expected(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled)

p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, p1))
f1 = relay.Function([p0, p1], out)

x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y, x])
return relay.Function([x], z)

dshape = (1, 16, 64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_concatenate()
test_tuple_root()
