Skip to content

Commit

Permalink
[relay] Relay annotation and partitioning for codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Dec 23, 2019
1 parent e6ff3f7 commit 0c694d1
Show file tree
Hide file tree
Showing 19 changed files with 1,279 additions and 5 deletions.
13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
}
};

/*!
* \brief Options for the operators used to annotate a compiler.
*/
struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
/*! \brief The 3rd party compiler for code generation. */
std::string compiler;

TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
TVM_ATTR_FIELD(compiler)
.describe("The 3rd compiler used for code generation.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
23 changes: 20 additions & 3 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
#include <string>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -122,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
* operator with other expressions. This function will be invoked
* in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param args The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
Expand All @@ -136,8 +137,8 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
* \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* \param args The input symbols of the original node.
* \param arg_types An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
Expand All @@ -146,6 +147,22 @@ using FTVMLegalize = runtime::TypedPackedFunc<
const Array<Expr>& args,
const Array<tvm::relay::Type>& arg_types)>;

/*!
* \brief Annotates an expression to indicate which compiler an op
* should be used for codegen.
*
* \param attrs The attribute of the original expr.
* \param args The arguments of the original expr.
* \param compiler The compiler that is used to compile the op.
*
* \return true if this op should be registered to invoke a specific compiler
* for codegen, otherwise, false.
*/
using FTVMAnnotateCompiler = runtime::TypedPackedFunc<
bool(const Attrs& attrs, // NOLINT(*)
const Array<Expr>& args,
const std::string& compiler)>;

/*!
* \brief Forward rewriting rule for a specific op.
*
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
*/
TVM_DLL Pass PrintIR(bool show_meta_data = true);

/*!
* \brief Partition a Relay program into regions that can be executed on
* different backends.
*
* \return The pass.
*/
TVM_DLL Pass PartitionGraph();

} // namespace transform

/*!
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from . import adt
from . import analysis
from . import transform
from .build_module import build, create_executor, optimize
from .build_module import build, create_executor, optimize, build_extern_compiler
from .transform import build_config
from . import prelude
from . import parser
Expand Down
29 changes: 29 additions & 0 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .module import Module as _Module
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
from . import transform as _transform

def _update_target(target):
target = target if target else _target.current_target()
Expand Down Expand Up @@ -296,6 +297,34 @@ def optimize(mod, target=None, params=None):
return mod, params


def build_extern_compiler(mod, compiler):
"""Helper function that annotates a Relay module and patitions the
expression init into various regions. These regions will be handled
by either default compilers in TVM stack or the provided external compiler.
Parameters
----------
mod : relay.Module
The module to build. Using relay.Function is deprecated.
compiler : str
The name of the external compiler.
Returns
-------
mod : relay.Module
The relay module contains partitioned program regions (e.g. functions)
that will be compiled using different compilers.
"""
if isinstance(mod, _expr.Function):
mod = _Module.from_expr(mod)

seq = _transform.Sequential([_transform.AnnotateCompiler(compiler),
_transform.PartitionGraph()])
mod = seq(mod)
return mod


class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# operator defs
from .op import get, register, register_schedule, register_compute, register_gradient, \
register_pattern, register_alter_op_layout, register_legalize, \
schedule_injective, Op, OpPattern, debug
register_annotate_compiler, schedule_injective, Op, OpPattern, debug

# Operators
from .reduce import *
Expand Down
41 changes: 41 additions & 0 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def stop_fusion(data):
"""
return _make.stop_fusion(data)


def checkpoint(data):
"""Annotate an expression to be a checkpoint for the checkpointing memory optimization.
Expand All @@ -78,3 +79,43 @@ def checkpoint(data):
return _make.checkpoint(data)

register_schedule("annotation.checkpoint", schedule_injective)


def compiler_begin(data, compiler):
"""Annotate an expression to indicate that it is the beginning of
a regeion that will be handled by the given compiler.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
compiler : Str
The compiler used to generate code of the annotated region.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_begin(data, compiler)


def compiler_end(data, compiler):
"""Annotate an expression to indicate that it is the end of a region that
is handled by the provided compiler.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
compiler : Str
The compiler used to generate code of the annotated region.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_end(data, compiler)
1 change: 1 addition & 0 deletions python/tvm/relay/op/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
"""Neural network related operators."""
from __future__ import absolute_import as _abs
from .contrib import *
from .annotate_compiler import *
from . import _contrib
119 changes: 119 additions & 0 deletions python/tvm/relay/op/contrib/annotate_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""
External compiler related feature registration.
It implements dispatchers that check if an operator should use a given compiler
to generate code.
Each compiler can customize the support of an operator. For example, they can
check the attribute of the operator and/or the features of the input arguments
to decide if we should use the compiler for codegen.
"""
from __future__ import absolute_import

import logging
import pkgutil
from pathlib import Path
from importlib import import_module

from .. import op as reg

logger = logging.getLogger('AnnotateCompiler')

# Load available contrib compilers
compilers = {}
for _, name, _ in pkgutil.iter_modules([Path(__file__).parent]):
compilers[name] = import_module(
'.%s' % name, package='.'.join(__name__.split('.')[:-1]))


def get_annotate_compiler(compiler, op_name):
"""Get the annotate_compiler function from the registered compilers.
Parameters
----------
compiler : Str
The name of a compiler that is used to generate code.
op_name : Str
The name of an operator.
Returns
-------
ret : bool
If the operator uses the provided compiler for codegen.
"""
if compiler in compilers:
if hasattr(compilers[compiler], 'annotate_compiler'):
annotate_compiler = getattr(compilers[compiler], 'annotate_compiler')
if hasattr(annotate_compiler, op_name):
return getattr(annotate_compiler, op_name)

logger.warning("%s in %s is not registered. Fallback to CPU", op_name,
compiler)
return lambda x, y: False


@reg.register_annotate_compiler("nn.conv2d")
def annotate_conv2d(attrs, args, compiler):
"""Check if the provided compiler should be used for conv2d.
"""
return get_annotate_compiler(compiler, 'conv2d')(attrs, args)


@reg.register_annotate_compiler("nn.dense")
def annotate_dense(attrs, args, compiler):
"""Check if the provided compiler should be used for dense.
"""
return get_annotate_compiler(compiler, 'dense')(attrs, args)


@reg.register_annotate_compiler("nn.relu")
def annotate_relu(attrs, args, compiler):
"""Check if the provided compiler should be used for relu.
"""
return get_annotate_compiler(compiler, 'relu')(attrs, args)


@reg.register_annotate_compiler("nn.batch_norm")
def annotate_batch_norm(attrs, args, compiler):
"""Check if the provided compiler should be used for batch_norm.
"""
return get_annotate_compiler(compiler, 'batch_norm')(attrs, args)


@reg.register_annotate_compiler("subtract")
def annotate_subtract(attrs, args, compiler):
"""Check if the provided compiler should be used for subtract.
"""
return get_annotate_compiler(compiler, 'subtract')(attrs, args)


@reg.register_annotate_compiler("add")
def annotate_add(attrs, args, compiler):
"""Check if the provided compiler should be used for add.
"""
return get_annotate_compiler(compiler, 'add')(attrs, args)


@reg.register_annotate_compiler("multiply")
def annotate_multiply(attrs, args, compiler):
"""Check if the provided compiler should be used for multiply.
"""
return get_annotate_compiler(compiler, 'multiply')(attrs, args)
20 changes: 20 additions & 0 deletions python/tvm/relay/op/contrib/csource/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Neural network related operators."""
from __future__ import absolute_import as _abs
from .annotate_compiler import *
39 changes: 39 additions & 0 deletions python/tvm/relay/op/contrib/csource/annotate_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""C/C++ compiler supported operators."""
from __future__ import absolute_import

def conv2d(attrs, args):
"""Check if the external C source codegen should be used.
"""
return False

def subtract(attrs, args):
"""Check if the external C source codegen should be used.
"""
return True

def add(attrs, args):
"""Check if the external C source codegen should be used.
"""
return True

def multiply(attrs, args):
"""Check if the external C source codegen should be used.
"""
return True
Loading

0 comments on commit 0c694d1

Please sign in to comment.