Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[relay] Relay annotation and partitioning for external compilers #4570

Merged
merged 7 commits into from
Jan 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
}
};

/*!
* \brief Options for the operators used to annotate a compiler.
*/
struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
/*! \brief A 3rd party compiler for code generation. */
std::string compiler;

TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
TVM_ATTR_FIELD(compiler)
.describe("A 3rd party compiler used for code generation.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
6 changes: 3 additions & 3 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
* operator with other expressions. This function will be invoked
* in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param args The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
Expand Down Expand Up @@ -153,8 +153,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc<
* \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* \param args The input symbols of the original node.
* \param arg_types An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
*/
TVM_DLL Pass PrintIR(bool show_meta_data = true);

/*!
* \brief Partition a Relay program into regions that can be executed on
* different backends.
*
* \return The pass.
*/
TVM_DLL Pass PartitionGraph();

} // namespace transform

/*!
Expand Down
41 changes: 41 additions & 0 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def stop_fusion(data):
"""
return _make.stop_fusion(data)


def checkpoint(data):
"""Annotate an expression to be a checkpoint for the checkpointing memory optimization.

Expand All @@ -78,3 +79,43 @@ def checkpoint(data):
return _make.checkpoint(data)

register_schedule("annotation.checkpoint", schedule_injective)


def compiler_begin(data, compiler):
"""Annotate an expression to indicate that it is the beginning of
a regeion that will be handled by the given compiler.

Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.

compiler : Str
The compiler used to generate code of the annotated region.

Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_begin(data, compiler)


def compiler_end(data, compiler):
"""Annotate an expression to indicate that it is the end of a region that
is handled by the provided compiler.

Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.

compiler : Str
The compiler used to generate code of the annotated region.

Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_end(data, compiler)
12 changes: 12 additions & 0 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,18 @@ def PrintIR(show_meta_data=True):
return _transform.PrintIR(show_meta_data)


def PartitionGraph():
"""Partition a Relay program into regions that can be executed on different
backends.

Returns
-------
ret: tvm.relay.Pass
The registered pass that partitions the Relay program.
"""
return _transform.PartitionGraph()


def gradient(expr, mod=None, mode='higher_order'):
"""
Transform the input function,
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {

if (ref->IsInstance<FunctionNode>()) {
GenDNNLFunc(Downcast<Function>(ref));
} else if (ref->IsInstance<relay::ModuleNode>()) {
relay::Module mod = Downcast<relay::Module>(ref);
} else if (ref->IsInstance<IRModuleNode>()) {
IRModule mod = Downcast<IRModule>(ref);
for (const auto& it : mod->functions) {
GenDNNLFunc(Downcast<Function>(it.second));
}
Expand Down
50 changes: 50 additions & 0 deletions src/relay/op/annotation/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization.
return outputs;
});

RELAY_REGISTER_OP("annotation.compiler_begin")
.describe(R"code(
Beginning of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});

TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
.set_body_typed([](Expr expr, std::string compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_begin");
return CallNode::make(op, {expr}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("annotation.compiler_end")
.describe(R"code(
End of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});

TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
.set_body_typed([](Expr expr, std::string compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_end");
return CallNode::make(op, {expr}, Attrs(attrs), {});
});

} // namespace relay
} // namespace tvm
3 changes: 1 addition & 2 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// Finally if the operator position is not a call node we will
// need to call Update, as it may be an arbitrary expression.
OpPatternKind op_pattern = kOpaque;
const OpNode* opnode = call->op.as<OpNode>();
if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
if (const OpNode* opnode = call->op.as<OpNode>()) {
op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
} else {
this->Update(call->op, node, kOpaque);
Expand Down
Loading