Skip to content

Commit

Permalink
[BYOC] Support control flow in annotate_target (#6641)
Browse files Browse the repository at this point in the history
* Change annotate target

* Annotate_target

* Revert namespace changes

* Add tests for if-else node

* Add while_let testcase

* No merging in ifelse

* Remove scope builder

* Add ops

* Replace < with less

* Linter

* Pass Tests

* Change back to static const

* Cpplinter

* address PR comments'

* PR Comments

* Clang-format check

* PR Comments

* PR Comments

* Change back to Insert Ann in AnnotateARgs

Co-authored-by: Ritwik Das <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
3 people authored Oct 13, 2020
1 parent 4073adc commit d5728bd
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 21 deletions.
66 changes: 45 additions & 21 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,25 @@ class AnnotateTargetRewriter : public ExprRewriter {
return new_op;
}

Expr InsertCompilerEndAndPropogateTarget(const Expr& expr) {
/*!
* \brief This function inserts compiler end to expr and maps the corresponding target to the
* new expression.
*
* This function checks for expr existence within the map and inserts the annotation
* Further, it propagates the target to the new expression and returns it
*
* \param expr A relay expression
* \return An annotated and target-propagated relay expression.
*/
Expr new_expr = expr;
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op);
op_expr_to_target_[new_expr] = op_expr_to_target_[expr];
}
return std::move(new_expr);
}

Expr Rewrite_(const CallNode* pre, const Expr& post) final {
// Supported targets for this node. The order implies the priority.
std::vector<std::string> supported_targets;
Expand All @@ -127,14 +146,16 @@ class AnnotateTargetRewriter : public ExprRewriter {
CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
}

// Peek the first argument. If it is compiler begin then this node had annotated by
// another target before, so we also consider that target as a supported target.
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
// Check prior to peeking first argument
if (pre->args.size()) {
// Peek the first argument. If it is compiler begin then this node had annotated by
// another target before, so we also consider that target as a supported target.
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
}
}
}

Expand Down Expand Up @@ -222,32 +243,35 @@ class AnnotateTargetRewriter : public ExprRewriter {
new_body = func->body;
} else {
func = Downcast<Function>(post);
new_body = func->body;
if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
op_expr_to_target_[new_body] = op_expr_to_target_[func->body];
}
new_body = InsertCompilerEndAndPropogateTarget(func->body);
}
return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
}

Expr Rewrite_(const LetNode* op, const Expr& post) final {
auto let = Downcast<Let>(post);

auto target_n_args = AnnotateArgs({let->value, let->body});
auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
Expr new_expr;
std::pair<std::string, Array<Expr>> target_n_args;
Expr new_body = InsertCompilerEndAndPropogateTarget(let->body);
// Do not annotate function literal with let binding.
if (let->value->IsInstance<FunctionNode>()) {
new_expr = Let(let->var, let->value, new_body);
} else {
target_n_args = AnnotateArgs({let->value});
new_expr = Let(let->var, std::get<1>(target_n_args)[0], new_body);
}

return std::move(new_expr);
}

Expr Rewrite_(const IfNode* op, const Expr& post) final {
auto expr = Downcast<If>(post);
Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond);
Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch);
Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch);

auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch});
CHECK_EQ(std::get<1>(target_n_args).size(), 3U);
auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1],
std::get<1>(target_n_args)[2]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
auto new_expr = If(new_cond, new_true_branch, new_false_branch);
return std::move(new_expr);
}

Expand Down
157 changes: 157 additions & 0 deletions tests/python/relay/test_pass_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,161 @@ def before():
assert tvm.ir.structural_equal(expected, mod)


def test_if_else():
target = "test_if_else"

@tvm.ir.register_op_attr("equal", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("tanh", "target." + target)
def tanh(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("sigmoid", "target." + target)
def sigmoid(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("erf", "target." + target)
def erf(attrs, args): # pylint: disable=unused-variable
return True

"""Test that If-else nodes compiles correctly when surrounded by supported nodes."""

def before():
data = relay.var("data", shape=(1, 32))
eq1 = relay.var("e1", shape=[], dtype="float32")
eq2 = relay.var("e2", shape=[], dtype="float32")
eq = relay.equal(eq1, eq2)

true_branch = relay.tanh(data)
false_branch = relay.sigmoid(data)
ife = relay.If(eq, true_branch, false_branch)
out = relay.erf(ife)
func = relay.Function([data, eq1, eq2], out)
mod = tvm.IRModule.from_expr(func)

return mod

def after():

data = relay.var("data", shape=(1, 32))
eq1 = relay.var("e1", shape=[], dtype="float32")
eq2 = relay.var("e2", shape=[], dtype="float32")

cb_1 = relay.annotation.compiler_begin(eq1, target)
cb_2 = relay.annotation.compiler_begin(eq2, target)

equality_condition = relay.equal(cb_1, cb_2)
ce_1 = relay.annotation.compiler_end(equality_condition, target)

# if condition
cb_3 = relay.annotation.compiler_begin(data, target)
true_branch = relay.tanh(cb_3)
ce_2 = relay.annotation.compiler_end(true_branch, target)

# else condition
cb_4 = relay.annotation.compiler_begin(data, target)
false_branch = relay.sigmoid(cb_4)
ce_3 = relay.annotation.compiler_end(false_branch, target)

if_condition = relay.If(ce_1, ce_2, ce_3)
cb_5 = relay.annotation.compiler_begin(if_condition, target)
erf_out = relay.erf(cb_5)
ce_4 = relay.annotation.compiler_end(erf_out, target)
func = relay.Function([data, eq1, eq2], ce_4)
mod = tvm.IRModule.from_expr(func)
return mod

result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)


def test_while_let():
target = "test_while_let"

@tvm.ir.register_op_attr("less", "target." + target)
def less(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("add", "target." + target)
def add(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("zeros_like", "target." + target)
def zeros_like(attrs, args): # pylint: disable=unused-variable
return True

"""Test that let nodes compiles correctly when surrounded by other nodes."""

def before():

var1 = relay.var("var1", shape=(2,))
var2 = relay.var("var2", shape=(), dtype="int32")
var3 = relay.var("var3", shape=(2,))
cond = relay.less(var2, relay.const(10, dtype="int32"))

loop = relay.var("while_loop")
ii = var2 + relay.const(1, dtype="int32")
ss = var3 + var1
true_branch = loop(ii, ss)
ife = relay.If(cond, true_branch, var3)
func_1 = relay.Function([var2, var3], ife)

ret = relay.Let(loop, func_1, loop(relay.const(0, dtype="int32"), relay.zeros_like(var1)))
func_2 = relay.Function([var1], ret)
mod = tvm.IRModule.from_expr(func_2)
return mod

def after():
var1 = relay.var("var1", shape=(2,))
var2 = relay.var("var2", shape=(), dtype="int32")
var3 = relay.var("var3", shape=(2,))
var4 = relay.const(10, dtype="int32")

cb_1 = relay.annotation.compiler_begin(var2, target)
cb_2 = relay.annotation.compiler_begin(var4, target)

less_condition = relay.less(cb_1, cb_2)
ce_1 = relay.annotation.compiler_end(less_condition, target)

loop = relay.var("while_loop")

# if condition
cb_3 = relay.annotation.compiler_begin(var2, target)
cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target)
add_op_1 = relay.add(cb_3, cb_4)
ce_2 = relay.annotation.compiler_end(add_op_1, target)
cb_5 = relay.annotation.compiler_begin(ce_2, "default")
cb_6 = relay.annotation.compiler_begin(var3, target)
cb_7 = relay.annotation.compiler_begin(var1, target)
add_op_2 = relay.add(cb_6, cb_7)
ce_3 = relay.annotation.compiler_end(add_op_2, target)
cb_8 = relay.annotation.compiler_begin(ce_3, "default")
true_branch = loop(cb_5, cb_8) # while loop
ce_4 = relay.annotation.compiler_end(true_branch, "default")
if_condition = relay.If(ce_1, ce_4, var3)

cb_9 = relay.annotation.compiler_begin(relay.const(0, dtype="int32"), "default")
cb_10 = relay.annotation.compiler_begin(var1, target)
zeros_like = relay.zeros_like(cb_10)
ce_5 = relay.annotation.compiler_end(zeros_like, target)
cb_11 = relay.annotation.compiler_begin(ce_5, "default")
while_condition = loop(cb_9, cb_11)
ce_6 = relay.annotation.compiler_end(while_condition, "default")

func_1 = relay.Function([var2, var3], if_condition)
ret = relay.Let(loop, func_1, ce_6)
func_2 = relay.Function([var1], ret)
mod = tvm.IRModule.from_expr(func_2)
return mod

result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)


if __name__ == "__main__":
test_extern_dnnl()
test_composite_function()
Expand All @@ -363,3 +518,5 @@ def before():
test_type_propagation()
test_tuple()
test_multiple_runs()
test_if_else()
test_while_let()

0 comments on commit d5728bd

Please sign in to comment.