Skip to content

Commit

Permalink
[relay] Relay annotation and partitioning for external compilers (#4570)
Browse files Browse the repository at this point in the history
* [relay] Relay annotation and partitioning for codegen

* Add fusion unit test

* fix comments

* Update include/tvm/relay/attrs/annotation.h

Co-Authored-By: 雾雨魔理沙 <[email protected]>

* rebase

* remove annotation helper

* rebase again

Co-authored-by: Cody Yu <[email protected]>
Co-authored-by: 雾雨魔理沙 <[email protected]>
  • Loading branch information
3 people committed Jan 14, 2020
1 parent d7d2a9b commit 3f2abfb
Show file tree
Hide file tree
Showing 10 changed files with 950 additions and 7 deletions.
13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
}
};

/*!
* \brief Options for the operators used to annotate a compiler.
*/
struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
/*! \brief A 3rd party compiler for code generation. */
std::string compiler;

TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
TVM_ATTR_FIELD(compiler)
.describe("A 3rd party compiler used for code generation.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
6 changes: 3 additions & 3 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
* operator with other expressions. This function will be invoked
* in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param args The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
Expand Down Expand Up @@ -153,8 +153,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc<
* \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* \param args The input symbols of the original node.
* \param arg_types An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
*/
TVM_DLL Pass PrintIR(bool show_meta_data = true);

/*!
* \brief Partition a Relay program into regions that can be executed on
* different backends.
*
* \return The pass.
*/
TVM_DLL Pass PartitionGraph();

} // namespace transform

/*!
Expand Down
41 changes: 41 additions & 0 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def stop_fusion(data):
"""
return _make.stop_fusion(data)


def checkpoint(data):
"""Annotate an expression to be a checkpoint for the checkpointing memory optimization.
Expand All @@ -78,3 +79,43 @@ def checkpoint(data):
return _make.checkpoint(data)

register_schedule("annotation.checkpoint", schedule_injective)


def compiler_begin(data, compiler):
"""Annotate an expression to indicate that it is the beginning of
a regeion that will be handled by the given compiler.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
compiler : Str
The compiler used to generate code of the annotated region.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_begin(data, compiler)


def compiler_end(data, compiler):
"""Annotate an expression to indicate that it is the end of a region that
is handled by the provided compiler.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
compiler : Str
The compiler used to generate code of the annotated region.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_end(data, compiler)
12 changes: 12 additions & 0 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,18 @@ def PrintIR(show_meta_data=True):
return _transform.PrintIR(show_meta_data)


def PartitionGraph():
"""Partition a Relay program into regions that can be executed on different
backends.
Returns
-------
ret: tvm.relay.Pass
The registered pass that partitions the Relay program.
"""
return _transform.PartitionGraph()


def gradient(expr, mod=None, mode='higher_order'):
"""
Transform the input function,
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {

if (ref->IsInstance<FunctionNode>()) {
GenDNNLFunc(Downcast<Function>(ref));
} else if (ref->IsInstance<relay::ModuleNode>()) {
relay::Module mod = Downcast<relay::Module>(ref);
} else if (ref->IsInstance<IRModuleNode>()) {
IRModule mod = Downcast<IRModule>(ref);
for (const auto& it : mod->functions) {
GenDNNLFunc(Downcast<Function>(it.second));
}
Expand Down
50 changes: 50 additions & 0 deletions src/relay/op/annotation/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization.
return outputs;
});

RELAY_REGISTER_OP("annotation.compiler_begin")
.describe(R"code(
Beginning of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});

TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
.set_body_typed([](Expr expr, std::string compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_begin");
return CallNode::make(op, {expr}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("annotation.compiler_end")
.describe(R"code(
End of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});

TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
.set_body_typed([](Expr expr, std::string compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_end");
return CallNode::make(op, {expr}, Attrs(attrs), {});
});

} // namespace relay
} // namespace tvm
3 changes: 1 addition & 2 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// Finally if the operator position is not a call node we will
// need to call Update, as it may be an arbitrary expression.
OpPatternKind op_pattern = kOpaque;
const OpNode* opnode = call->op.as<OpNode>();
if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
if (const OpNode* opnode = call->op.as<OpNode>()) {
op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
} else {
this->Update(call->op, node, kOpaque);
Expand Down
Loading

0 comments on commit 3f2abfb

Please sign in to comment.