Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Add support for tuple node in operator fusion #2187

Merged
merged 10 commits into from
Nov 30, 2018
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 52 additions & 21 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}

void VisitExpr_(const TupleNode* op) {
CHECK(graph_.node_map.count(op));
Node* tuple_node = graph_.node_map.at(op);
tuple_node->pattern = kInjective;
for (const Expr& field : op->fields) {
this->Update(field, nullptr, kOpaque);
this->Update(field, tuple_node, kInjective);
}
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
Expand Down Expand Up @@ -712,32 +715,15 @@ class FuseMutator : private ExprMutator {
// then we must have a group assignment for it already.
CHECK(gmap_.count(call));
auto* ret_group = gmap_.at(call)->FindRoot();
Array<Expr> new_args;
for (auto arg : call->args) {
auto type = arg->checked_type();
CHECK(gmap_.count(arg.get()))
<< "cannot find group of " << arg;
auto* arg_group = gmap_.at(arg.get())->FindRoot();
Expr new_arg = this->Mutate(arg);

if (ret_group != arg_group) {
Var param = ginfo_[ret_group].GetOrAllocParam(new_arg, type);
new_args.push_back(param);
} else {
new_args.push_back(new_arg);
}
}
Array<Expr> new_args = GetNewArguments(call->args, ret_group);

auto new_call = CallNode::make(
call->op, new_args, call->attrs, call->type_args);

if (ret_group->root_ref == call) {
// This is the root of the group
// create the new call node.
const GroupInfo& ginfo = ginfo_[ret_group];
auto func = FunctionNode::make(
ginfo.params, new_call, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
return CallNode::make(func, ginfo.arguments, Attrs());
return MakeNewFunction(ret_group, call->checked_type(), new_call);
} else {
// This is an intermediate node of a fused function
// simply return the new call.
Expand All @@ -747,6 +733,51 @@ class FuseMutator : private ExprMutator {
return ExprMutator::VisitExpr_(call);
}
}

Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
masahi marked this conversation as resolved.
Show resolved Hide resolved
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
Tuple new_tuple = TupleNode::make(new_fields);
if (ret_group == gmap_.at(tuple)) {
bool isolated = true;
for (size_t i = 0; i < new_fields.size(); ++i) {
isolated &= (new_fields[i] == ginfo_[ret_group].params[i]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to new_fields[i].same_as(ginfo_[ret_group].params[i]), in case == get overloaded in the future

}
if (isolated) {
// Do not put a isolated tuple into a function
return ExprMutator::VisitExpr_(tuple);
}
// This tuple has been fused with other ops before it
return MakeNewFunction(ret_group, tuple->checked_type(), new_tuple);
}
// This tuple is an intermediate node in the group
return new_tuple;
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
return CallNode::make(func, ginfo.arguments, Attrs());
}

Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
GraphPartitioner::Group* current_group) {
Array<Expr> new_args;
for (auto arg : args) {
auto* arg_group = gmap_.at(arg.get())->FindRoot();
auto type = arg->checked_type();
Expr new_arg = this->Mutate(arg);
if (current_group != arg_group) {
Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
new_args.push_back(param);
} else {
new_args.push_back(new_arg);
}
}
return new_args;
}

// Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) {
std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string {
Expand Down
81 changes: 79 additions & 2 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ def expected():
assert relay.ir_pass.alpha_equal(zz, after)




def test_conv2d_fuse():
"""Test fusion case of conv2d"""
def before(dshape):
Expand Down Expand Up @@ -106,7 +104,86 @@ def expected(dshape):
assert relay.ir_pass.alpha_equal(zz, after)


def test_concatenate():
"""Test fusion case involving concat op and Tuple node"""

def before(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
concat = relay.concatenate((upsampled, x), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
return relay.Function(relay.ir_pass.free_vars(out), out)

def expected(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled)

p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=dshape)
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
concat = relay.concatenate((upsampled, p1), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
f1 = relay.Function([p0, p1], out)

x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y, x])
return relay.Function([x], z)

dshape = (1, 16, 64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)


def test_tuple_root():
"""Test fusion case where Tuple node is the root in its group"""

def before(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, x))
return relay.Function(relay.ir_pass.free_vars(out), out)

def expected(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled)

p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, p1))
f1 = relay.Function([p0, p1], out)

x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y, x])
return relay.Function([x], z)

dshape = (1, 16, 64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_concatenate()
test_tuple_root()