diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h
index fd21db5a9c147..4481d2adcd5fe 100644
--- a/include/tvm/relay/attrs/annotation.h
+++ b/include/tvm/relay/attrs/annotation.h
@@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
   }
 };
 
+/*!
+ * \brief Options for the operators used to annotate a compiler.
+ */
+struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
+  /*! \brief A 3rd party compiler for code generation. */
+  std::string compiler;
+
+  TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
+    TVM_ATTR_FIELD(compiler)
+        .describe("A 3rd party compiler used for code generation.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_ANNOTATION_H_
diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h
index 9cfa755ef8132..b6221e0ba8a5b 100644
--- a/include/tvm/relay/op_attr_types.h
+++ b/include/tvm/relay/op_attr_types.h
@@ -123,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
  * operator with other expressions. This function will be invoked
  * in AlterOpLayout pass.
  * \param attrs The attribute of the original node.
- * \param inputs The input symbols of the original node.
+ * \param args The input symbols of the original node.
  * \param tinfos An array of placeholders, use for getting the inferred shape
  *        and dtype of the inputs.
  * \return new_expr The modified expression.
@@ -153,8 +153,8 @@
  * \brief Legalizes an expression with another expression. This function will be
  * invoked in Legalize pass. It is a target-dependent pass.
  * \param attrs The attribute of the original node.
- * \param inputs The input symbols of the original node.
- * \param tinfos An array of placeholders, use for getting the inferred shape
+ * \param args The input symbols of the original node.
+ * \param arg_types An array of placeholders, use for getting the inferred shape
  *        and dtype of the inputs.
  * \return new_expr The modified expression.
  */
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 294ffb995c5e7..58cfbfcc2b1d0 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -310,6 +310,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
  */
 TVM_DLL Pass PrintIR(bool show_meta_data = true);
 
+/*!
+ * \brief Partition a Relay program into regions that can be executed on
+ * different backends.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass PartitionGraph();
+
 }  // namespace transform
 
 /*!
diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py
index 2b9d4bcd81bc3..93639251beab0 100644
--- a/python/tvm/relay/op/annotation/annotation.py
+++ b/python/tvm/relay/op/annotation/annotation.py
@@ -62,6 +62,7 @@ def stop_fusion(data):
     """
     return _make.stop_fusion(data)
 
+
 def checkpoint(data):
     """Annotate an expression to be a checkpoint for the checkpointing memory
     optimization.
@@ -78,3 +79,43 @@ def checkpoint(data):
     return _make.checkpoint(data)
 
 register_schedule("annotation.checkpoint", schedule_injective)
+
+
+def compiler_begin(data, compiler):
+    """Annotate an expression to indicate that it is the beginning of
+    a region that will be handled by the given compiler.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The expression to be annotated.
+
+    compiler : str
+        The compiler used to generate code of the annotated region.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The annotated expression.
+    """
+    return _make.compiler_begin(data, compiler)
+
+
+def compiler_end(data, compiler):
+    """Annotate an expression to indicate that it is the end of a region that
+    is handled by the provided compiler.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The expression to be annotated.
+
+    compiler : str
+        The compiler used to generate code of the annotated region.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The annotated expression.
+    """
+    return _make.compiler_end(data, compiler)
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
index 1f91272769b4b..c4fbde60a6eb9 100644
--- a/python/tvm/relay/transform.py
+++ b/python/tvm/relay/transform.py
@@ -663,6 +663,18 @@ def PrintIR(show_meta_data=True):
     return _transform.PrintIR(show_meta_data)
 
 
+def PartitionGraph():
+    """Partition a Relay program into regions that can be executed on different
+    backends.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that partitions the Relay program.
+    """
+    return _transform.PartitionGraph()
+
+
 def gradient(expr, mod=None, mode='higher_order'):
     """
     Transform the input function,
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 9c24944c83b48..fbe047d26a5cd 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -270,8 +270,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
     if (ref->IsInstance<FunctionNode>()) {
       GenDNNLFunc(Downcast<Function>(ref));
-    } else if (ref->IsInstance<relay::ModuleNode>()) {
-      relay::Module mod = Downcast<relay::Module>(ref);
+    } else if (ref->IsInstance<IRModuleNode>()) {
+      IRModule mod = Downcast<IRModule>(ref);
       for (const auto& it : mod->functions) {
         GenDNNLFunc(Downcast<Function>(it.second));
       }
diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc
index efcb383d5e9d6..3d03f884e2470 100644
--- a/src/relay/op/annotation/annotation.cc
+++ b/src/relay/op/annotation/annotation.cc
@@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization.
     return outputs;
   });
 
+RELAY_REGISTER_OP("annotation.compiler_begin")
+.describe(R"code(
+Beginning of a region that is handled by a given compiler.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(10)
+.add_type_rel("Identity", IdentityRel)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               ElemwiseArbitraryLayout)
+.set_attr<FTVMCompute>("FTVMCompute",
+                       [](const Attrs& attrs, const Array<Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                         return {topi::identity(inputs[0])};
+                       });
+
+TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
+.set_body_typed([](Expr expr, std::string compiler) {
+  auto attrs = make_object<CompilerAttrs>();
+  attrs->compiler = compiler;
+  static const Op& op = Op::Get("annotation.compiler_begin");
+  return CallNode::make(op, {expr}, Attrs(attrs), {});
+});
+
+RELAY_REGISTER_OP("annotation.compiler_end")
+.describe(R"code(
+End of a region that is handled by a given compiler.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(10)
+.add_type_rel("Identity", IdentityRel)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               ElemwiseArbitraryLayout)
+.set_attr<FTVMCompute>("FTVMCompute",
+                       [](const Attrs& attrs, const Array<Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                         return {topi::identity(inputs[0])};
+                       });
+
+TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
+.set_body_typed([](Expr expr, std::string compiler) {
+  auto attrs = make_object<CompilerAttrs>();
+  attrs->compiler = compiler;
+  static const Op& op = Op::Get("annotation.compiler_end");
+  return CallNode::make(op, {expr}, Attrs(attrs), {});
+});
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index bf38a48b2528b..e18dbc27d3678 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -242,8 +242,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     // Finally if the operator position is not a call node we will
     // need to call Update, as it may be an arbitrary expression.
     OpPatternKind op_pattern = kOpaque;
-    const OpNode* opnode = call->op.as<OpNode>();
-    if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
+    if (const OpNode* opnode = call->op.as<OpNode>()) {
       op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
     } else {
       this->Update(call->op, node, kOpaque);
diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc
new file mode 100644
index 0000000000000..634affebdebd2
--- /dev/null
+++ b/src/relay/pass/partition_graph.cc
@@ -0,0 +1,386 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * \file src/relay/pass/partition_graph.cc
+ *
+ * \brief Partition an input function into multiple functions based on the
+ * inserted annotation nodes (i.e. compiler_begin and compiler_end). These
+ * nodes are used as boundaries to partition the Relay function into multiple
+ * regions that can be offloaded to different accelerators/backends.
+ *
+ * Each of these partitioned functions, a.k.a. subgraphs, will be viewed as
+ * external functions, and they will use the provided compiler for codegen.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relay {
+namespace partitioning {
+
+// Cache compiler_begin and compiler_end annotation ops for equivalence check to
+// reduce registry lookup overhead.
+static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
+
+/*!
+ * \brief The subgraph properties for partitioning.
+ */
+struct Subgraph {
+  /*! \brief The subgraph ID.
*/ + int id; + + /*! \brief The input arguments of this subgraph. */ + std::vector> args; + + /*! \brief Nodes in this subgraph. */ + std::unordered_set nodes; +}; + +/*! + * \brief The checker that verifies if a Relay program is annotated correctly + * for partitioning. + */ +class AnnotationChecker : public ExprVisitor { + public: + bool Check() { + if (!found_start_ && !found_end_) { + LOG(WARNING) << "No compiler annotation found"; + } else if (!found_start_) { + LOG(ERROR) << "compiler_begin annotation is missing"; + return false; + } else if (!found_end_) { + LOG(ERROR) << "compiler_end annotation is missing"; + return false; + } + return true; + } + + void VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as(); + if (op_node == nullptr || call->attrs.as() == nullptr) { + return; + } else if (call->op == compiler_begin_op) { + found_start_ = true; + } else if (call->op == compiler_end_op) { + found_end_ = true; + } + } + + private: + bool found_start_{false}; + bool found_end_{false}; +}; + +/*! \brief This class partitions the expr labeled with begin and end annoations + * into function containing multiple regions. Each region is labeled with + * a compiler attribute so that it will be handled by any compilers that are not + * in the TVM stack. + * + * TODO(@zhiics) This following algorithm is not adequate to handle all cases, + * i.e. multiple `compiler_end` nodes. + */ +class Partitioner : public ExprMutator { + public: + std::shared_ptr GetSubgraph(const Expr node) { + for (auto candidate : this->subgraphs_) { + if (candidate->nodes.find(node) != candidate->nodes.end()) { + return candidate; + } + } + return nullptr; + } + + void MergeSubgraph(std::shared_ptr subgraph1, + std::shared_ptr subgraph2) { + if (subgraph1 == subgraph2) { + return; + } + + // Merge subgraph 2 to subgraph 1 and erase subgraph 2. + subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end()); + for (auto arg : subgraph2->args) { + subgraph1->args.push_back(arg); + } + this->subgraphs_.erase(subgraph2); + } + + void AddToSubgraph(std::shared_ptr subgraph, const Expr expr) { + auto subgraph2 = GetSubgraph(expr); + if (subgraph2) { + MergeSubgraph(subgraph, subgraph2); + } else { + subgraph->nodes.insert(expr); + } + } + + Expr VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as(); + + if (op_node == nullptr || call->attrs.as() == nullptr) { + // Propogate subgraph to arguments + auto subgraph = GetSubgraph(GetRef(call)); + if (subgraph) { + for (auto arg : call->args) { + AddToSubgraph(subgraph, arg); + } + } + return ExprMutator::VisitExpr_(call); + } else if (call->op == compiler_begin_op) { + // The annotation node is inserted on edge so it must have only one argument. + CHECK_EQ(call->args.size(), 1U); + + // Traverse the rest graph. + auto input_expr = VisitExpr(call->args[0]); + + // Replace the begin annotation with an external call input variable. + auto compiler_attrs = call->attrs.as(); + auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++), + input_expr->checked_type_); + + // Find the corresponding subgraph and add the argument. 
+ auto subgraph = GetSubgraph(GetRef(call)); + if (!subgraph) { + throw Error(ErrorBuilder() + << "Cannot find the corresponding subgraph for start annotation:\n" + << AsText(GetRef(call), false)); + } + subgraph->args.push_back({var, input_expr}); + return std::move(var); + } else { + CHECK_EQ(call->op, compiler_end_op); + // The annotation node is inserted on edge so it must have only one argument. + CHECK_EQ(call->args.size(), 1U); + + auto compiler_attrs = call->attrs.as(); + + // Check if the argument already belongs to an exist subgraph + auto subgraph = GetSubgraph(call->args[0]); + if (!subgraph) { + auto ret = this->subgraphs_.emplace(std::make_shared()); + subgraph = *ret.first; + subgraph->nodes.insert(call->args[0]); + subgraph->id = this->subgraph_id_++; + } + subgraph->nodes.insert(GetRef(call)); + + // Traverse subgraph inputs. + auto input = VisitExpr(call->args[0]); + Array params; + Array args; + + // The subgraph may be merged so we need to update it again. + subgraph = GetSubgraph(GetRef(call)); + CHECK(subgraph); + + for (auto pair : subgraph->args) { + params.push_back(pair.first); + args.push_back(pair.second); + } + + auto subgraph_func = + FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs()); + + Expr arg0 = call->args[0]; + std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id); + subgraph_func = + FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tvm::ir::StringImmNode::make(name)); + subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1)); + subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler, + tvm::ir::StringImmNode::make(compiler_attrs->compiler)); + return CallNode::make(subgraph_func, args); + } + } + + Expr VisitExpr_(const TupleNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + for (auto field : op->fields) { + AddToSubgraph(subgraph, field); + } + Array fields; + for (auto field : op->fields) { + fields.push_back(VisitExpr(field)); + } + return TupleNode::make(fields); + } + } + + Expr VisitExpr_(const TupleGetItemNode* g) final { + auto subgraph = GetSubgraph(GetRef(g)); + if (!subgraph) { + return ExprMutator::VisitExpr_(g); + } else { + AddToSubgraph(subgraph, g->tuple); + auto t = VisitExpr(g->tuple); + return TupleGetItemNode::make(t, g->index); + } + } + + Expr VisitExpr_(const FunctionNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + Array params; + for (auto param : op->params) { + AddToSubgraph(subgraph, param); + } + for (auto param : op->params) { + Var new_param = Downcast(VisitExpr(param)); + params.push_back(new_param); + } + auto body = VisitExpr(op->body); + return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs); + } + } + + Expr VisitExpr_(const LetNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->var); + AddToSubgraph(subgraph, op->value); + AddToSubgraph(subgraph, op->body); + Var var = Downcast(VisitExpr(op->var)); + auto value = VisitExpr(op->value); + auto body = VisitExpr(op->body); + + return LetNode::make(var, value, body); + } + } + + Expr VisitExpr_(const IfNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->cond); + AddToSubgraph(subgraph, 
op->true_branch); + AddToSubgraph(subgraph, op->false_branch); + auto guard = VisitExpr(op->cond); + auto true_b = VisitExpr(op->true_branch); + auto false_b = VisitExpr(op->false_branch); + return IfNode::make(guard, true_b, false_b); + } + } + + Expr VisitExpr_(const RefCreateNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->value); + Expr value = VisitExpr(op->value); + return RefCreateNode::make(value); + } + } + + Expr VisitExpr_(const RefReadNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->ref); + Expr ref = VisitExpr(op->ref); + return RefReadNode::make(ref); + } + } + + Expr VisitExpr_(const RefWriteNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->ref); + Expr ref = VisitExpr(op->ref); + Expr value = VisitExpr(op->value); + return RefWriteNode::make(ref, value); + } + } + + private: + int var_id_{0}; + int subgraph_id_{0}; + std::unordered_set> subgraphs_; +}; + +/*! + * \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to + * the same codegen backend. This reduces rounds trips between TVM and external + * backends. Likely we can borrow some ideas from operator fusion. + * + * For example, sg1 and sg2 should be combined if they belong to the same + * codegen tool in the following case. + * + * op1 + * / \ + * sg1 sg2 + * + * | + * \|/ + * + * op1 + * | + * sg1_sg2 + * + * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two + * inputs that obtained from the tuple. + */ + +Expr PartitionGraph(const Expr& expr) { + Partitioner part; + return part.Mutate(expr); +} + +} // namespace partitioning + +namespace transform { + +Pass PartitionGraph() { + runtime::TypedPackedFunc part_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(partitioning::PartitionGraph(f)); + }; + auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {}); + return Sequential({partitioned, InferType()}); +} + +TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph") +.set_body_typed(transform::PartitionGraph); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py new file mode 100644 index 0000000000000..4ffb373116968 --- /dev/null +++ b/tests/python/relay/test_pass_partition_graph.py @@ -0,0 +1,434 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Unit tests for graph partitioning.""" +import os +import sys +import numpy as np +import pytest + +import tvm +import tvm.relay.testing +import tvm.relay.transform as transform +from tvm import relay +from tvm.contrib import util +from tvm.relay.annotation import compiler_begin, compiler_end +from tvm.relay.expr_functor import ExprMutator + +# Leverage the pass manager to write a simple white list based annotator +@transform.function_pass(opt_level=0) +class WhiteListAnnotator: + def __init__(self, op_list, compiler): + assert isinstance(op_list, (list, tuple, set)) + self.op_list = op_list + self.compiler = compiler + + def transform_function(self, func, mod, ctx): + + annotator = self + class Annotator(tvm.relay.ExprMutator): + def visit_call(self, call): + op_name = call.op.name + if op_name in annotator.op_list: + new_args = [] + for arg in call.args: + ann = compiler_begin(super().visit(arg), + annotator.compiler) + new_args.append(ann) + new_call = relay.Call(call.op, new_args, call.attrs, + call.type_args) + return compiler_end(new_call, annotator.compiler) + else: + return super().visit_call(call) + return Annotator().visit(func) + + +class CcompilerAnnotator(ExprMutator): + """ + A simple annotator that creates the following program: + | + -- begin -- + | + add + | + subtract + | + multiply + | + -- end -- + | + """ + + def __init__(self): + super(CcompilerAnnotator, self).__init__() + self.in_compiler = 0 + + def visit_call(self, call): + if call.op.name == "add": # Annotate begin at args + if self.in_compiler == 1: + lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") + rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") + op = relay.add(lhs, rhs) + self.in_compiler = 2 + return op + elif call.op.name == "subtract": + if self.in_compiler == 1: + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = compiler_begin(lhs, "ccompiler") + if isinstance(rhs, relay.expr.Var): + rhs = compiler_begin(rhs, "ccompiler") + return relay.subtract(lhs, rhs) + elif call.op.name == "multiply": # Annotate end at output + self.in_compiler = 1 + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = compiler_begin(lhs, "ccompiler") + if isinstance(rhs, relay.expr.Var): + rhs = compiler_begin(rhs, "ccompiler") + op = relay.multiply(lhs, rhs) + if self.in_compiler == 2: + op = compiler_end(op, "ccompiler") + self.in_compiler = 0 + return op + return super().visit_call(call) + + +class WholeGraphAnnotator(ExprMutator): + """ + An annotator that creates a compiler for an entire graph. + """ + + def __init__(self, compiler): + super(WholeGraphAnnotator, self).__init__() + self.compiler = compiler + self.last_call = True + + def visit_call(self, call): + curr_last = self.last_call + self.last_call = False + + params = [] + for arg in call.args: + param = super().visit(arg) + if isinstance(param, relay.expr.Var): + param = compiler_begin(param, self.compiler) + params.append(param) + + new_call = relay.Call(call.op, params, call.attrs) + if curr_last: + new_call = compiler_end(new_call, self.compiler) + return new_call + + +class MobileNetAnnotator(ExprMutator): + """ + Annotate mobilenet until global_avg_pool. 
+ """ + + def __init__(self, compiler): + super(MobileNetAnnotator, self).__init__() + self.compiler = compiler + self.compiler_open = False + + def visit_call(self, call): + + if call.op.name == 'nn.global_avg_pool2d': + self.compiler_open = True + compiler_open = self.compiler_open + + params = [] + for arg in call.args: + param = super().visit(arg) + if call.op.name == 'nn.global_avg_pool2d': + param = compiler_end(param, self.compiler) + if compiler_open and isinstance(param, relay.expr.Var): + param = compiler_begin(param, self.compiler) + params.append(param) + + new_call = relay.Call(call.op, params, call.attrs) + return new_call + + +def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", + ctx=tvm.cpu(), params=None): + if sys.platform == "win32": + print("Skip test on Windows for now") + return + + def update_lib(lib): + test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + source_dir = os.path.join(test_dir, "..", "..", "..") + contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") + + kwargs = {} + kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + tmp_path = util.tempdir() + lib_name = 'lib.so' + lib_path = tmp_path.relpath(lib_name) + lib.export_library(lib_path, fcompile=False, **kwargs) + lib = tvm.module.load(lib_path) + + return lib + + def check_vm_result(): + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + exe = relay.vm.compile(mod, target=target, params=params) + code, lib = exe.save() + lib = update_lib(lib) + exe = relay.vm.Executable.load_exec(code, lib) + vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) + out = vm.run(**map_inputs) + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + def check_graph_runtime_result(): + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + json, lib, param = relay.build(mod, target=target, params=params) + lib = update_lib(lib) + rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) + + for name, data in map_inputs.items(): + rt_mod.set_input(name, data) + rt_mod.set_input(**param) + rt_mod.run() + out = tvm.nd.empty(out_shape, ctx=ctx) + out = rt_mod.get_output(0, out) + + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + check_vm_result() + check_graph_runtime_result() + + +def test_multi_node_compiler(): + x = relay.var('x', shape=(10, 10)) + w0 = relay.var('w0', shape=(10, 10)) + w1 = relay.var('w1', shape=(10, 10)) + w2 = relay.var('w2', shape=(10, 10)) + w3 = relay.var('w3', shape=(10, 10)) + w4 = relay.var('w4', shape=(10, 10)) + w5 = relay.var('w5', shape=(10, 10)) + w6 = relay.var('w6', shape=(10, 10)) + w7 = relay.var('w7', shape=(10, 10)) + + # C compiler + # FIXME: We generate two compilers for this case but they should be merged to one + # due to the common input (x). 
+ z0 = relay.add(x, w0) + p0 = relay.subtract(z0, w1) + q0 = relay.multiply(p0, w2) + + z1 = relay.add(x, w3) + p1 = relay.subtract(z1, w4) + q1 = relay.multiply(p1, w5) + + # Other parts on TVM + z2 = relay.add(x, w6) + q2 = relay.subtract(z2, w7) + + r = relay.concatenate((q0, q1, q2), axis=0) + f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) + mod = relay.Module() + ann = CcompilerAnnotator() + mod["main"] = ann.visit(f) + mod = transform.PartitionGraph()(mod) + mod = transform.InferType()(mod) + + x_data = np.random.rand(10, 10).astype('float32') + w_data = [] + for _ in range(8): + w_data.append(np.random.rand(10, 10).astype('float32')) + + map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} + map_inputs["x"] = x_data + check_result( + mod, map_inputs, (30, 10), + np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2], + ((x_data + w_data[3]) - w_data[4]) * w_data[5], + x_data + w_data[6] - w_data[7]), + axis=0)) + + +def test_extern_ccompiler_single_op(): + @transform.function_pass(opt_level=0) + class MyAnnotator: + def transform_function(self, func, mod, ctx): + class Annotator(tvm.relay.ExprMutator): + def visit_call(self, call): + new_args = [] + for arg in call.args: + ann = compiler_begin(self.visit(arg), "ccompiler") + new_args.append(ann) + new_call = relay.Call(call.op, new_args) + return compiler_end(new_call, "ccompiler") + return Annotator().visit(func) + + x = relay.var('x', shape=(8, 8)) + y = relay.var('y', shape=(8, 8)) + z = x + y + f = relay.Function([x, y], z) + x_data = np.random.rand(8, 8).astype('float32') + y_data = np.random.rand(8, 8).astype('float32') + mod = relay.Module() + mod["main"] = f + mod = MyAnnotator()(mod) + mod = transform.PartitionGraph()(mod) + + check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) + + +def test_extern_ccompiler_default_ops(): + def expected(): + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + x0 = relay.var("x0", shape=(8, 8)) + y0 = relay.var("y0", shape=(8, 8)) + add = x0 + y0 + # Function that uses C compiler + func = relay.Function([x0, y0], add) + func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1)) + func = func.set_attribute("Compiler", + tvm.expr.StringImm("ccompiler")) + func = func.set_attribute("ExternalSymbol", + tvm.expr.StringImm("ccompiler_0")) + add_call = relay.Call(func, [x, y]) + # Function that uses default compiler. Ops are fused in this function. 
+ p0 = relay.var("p0", shape=(8, 8)) + log = relay.log(p0) + exp = relay.exp(p0) + concat = relay.concatenate([log, exp], axis=0) + fused_func = relay.Function([p0], concat) + fused_func = fused_func.set_attribute("Primitive", + tvm.expr.IntImm("int32", 1)) + fused_call = relay.Call(fused_func, [add_call]) + main = relay.Function([x, y], fused_call) + mod = relay.Module() + mod["main"] = main + return mod + + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + add = x + y + log = relay.log(add) + exp = relay.exp(add) + concat = relay.concatenate([log, exp], axis=0) + f = relay.Function([x, y], concat) + mod = relay.Module() + mod["main"] = f + mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) + mod = transform.PartitionGraph()(mod) + + fused_mod = transform.FuseOps(2)(mod) + expected_mod = expected() + assert relay.alpha_equal(fused_mod, expected_mod) + + x_data = np.random.rand(8, 8).astype('float32') + y_data = np.random.rand(8, 8).astype('float32') + np_add = x_data + y_data + res = np.concatenate([np.log(np_add), np.exp(np_add)]) + check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res) + + +def test_extern_ccompiler(): + x = relay.var('x', shape=(2, 2)) + y = relay.var('y', shape=(2, 2)) + z = x + x + p = y * y + f = relay.Function([x, y], p - z) + x_data = np.random.rand(2, 2).astype('float32') + y_data = np.random.rand(2, 2).astype('float32') + mod = relay.Module() + mod["main"] = f + mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) + mod = transform.PartitionGraph()(mod) + + check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data)) + + +def test_extern_dnnl(): + if not tvm.get_global_func("relay.ext.dnnl", True): + print("skip because DNNL codegen is not available") + return + + dtype = 'float32' + ishape = (1, 32, 14, 14) + w1shape = (32, 1, 3, 3) + data = relay.var('data', shape=(ishape), dtype=dtype) + weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype) + depthwise_conv2d_1 = relay.nn.conv2d(data, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + + f = relay.Function([data, weight1], out) + + mod = relay.Module() + mod['main'] = WholeGraphAnnotator('dnnl').visit(f) + mod = transform.PartitionGraph()(mod) + + ref_mod = relay.Module() + ref_mod['main'] = f + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) + + ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu()) + ref_res = ref_ex.evaluate()(i_data, w1_data) + check_result(mod, {"data": i_data, "weight1": w1_data}, + (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5) + + +def test_extern_dnnl_mobilenet(): + if not tvm.get_global_func("relay.ext.dnnl", True): + print("skip because DNNL codegen is not available") + return + + dtype = 'float32' + ishape = (1, 3, 224, 224) + mod, params = relay.testing.mobilenet.get_workload( + batch_size=1, dtype='float32') + + op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"] + mod = WhiteListAnnotator(op_list, "dnnl")(mod) + mod = transform.PartitionGraph()(mod) + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + + ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, + dtype='float32') + ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0)) + ref_res = 
ref_ex.evaluate()(i_data, **params)
+
+    check_result(mod, {"data": i_data},
+                 (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
+
+
+if __name__ == "__main__":
+    test_multi_node_compiler()
+    test_extern_ccompiler_single_op()
+    test_extern_ccompiler_default_ops()
+    test_extern_ccompiler()
+    test_extern_dnnl()
+    test_extern_dnnl_mobilenet()
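
Below is a minimal usage sketch of the new annotation + partitioning flow. It is illustrative only and not part of the patch: it assumes the "ccompiler" test backend exercised by the unit tests above, and the helper name make_annotated_add_module is hypothetical. A single add is wrapped in compiler_begin/compiler_end, and PartitionGraph then lifts that region into a separate function carrying the Primitive/Compiler/ExternalSymbol attributes.

# Illustrative sketch only (not part of this patch); assumes the "ccompiler"
# test backend used by the unit tests above.
from tvm import relay
import tvm.relay.transform as transform
from tvm.relay.annotation import compiler_begin, compiler_end


def make_annotated_add_module():
    # Hypothetical helper: build z = x + y with the add marked for offload.
    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    lhs = compiler_begin(x, "ccompiler")   # region entry on each input edge
    rhs = compiler_begin(y, "ccompiler")
    out = compiler_end(relay.add(lhs, rhs), "ccompiler")  # region exit on the output
    mod = relay.Module()
    mod["main"] = relay.Function([x, y], out)
    return mod


mod = make_annotated_add_module()
# PartitionGraph replaces the annotated region with a call to a new function
# marked with the Primitive/Compiler/ExternalSymbol attributes, so that the
# codegen registered for "ccompiler" handles it at build time.
mod = transform.PartitionGraph()(mod)
mod = transform.InferType()(mod)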