Skip to content

Commit

Permalink
[REFACTOR] TVM_REGISTER_API -> TVM_REGISTER_GLOBAL (apache#4621)
Browse files Browse the repository at this point in the history
TVM_REGSISTER_API is an alias of TVM_REGISTER_GLOBAL.
In the spirit of simplify redirections, this PR removes
the original TVM_REGISTER_API macro and directly use TVM_REGISTER_GLOBAL.

This type of refactor will also simplify the IDE navigation tools
such as FFI navigator to provide better code reading experiences.

Move EnvFunc's definition to node.
  • Loading branch information
tqchen authored and zhiics committed Mar 2, 2020
1 parent bd96887 commit 3aee8b5
Show file tree
Hide file tree
Showing 131 changed files with 602 additions and 550 deletions.
4 changes: 2 additions & 2 deletions docs/dev/codebase_walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ The Node system is the basis of exposing C++ types to frontend languages, includ

::

TVM_REGISTER_API("_ComputeOp")
TVM_REGISTER_GLOBAL("_ComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ComputeOpNode::make(args[0],
args[1],
Expand Down Expand Up @@ -174,7 +174,7 @@ The ``Build()`` function looks up the code generator for the given target in the

::

TVM_REGISTER_API("codegen.build_cuda")
TVM_REGISTER_GLOBAL("codegen.build_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildCUDA(args[0]);
});
Expand Down
4 changes: 2 additions & 2 deletions docs/dev/relay_add_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ the arguments to the call node, as below.

.. code:: c
TVM_REGISTER_API("relay.op._make.add")
TVM_REGISTER_GLOBAL("relay.op._make.add")
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
Expand All @@ -106,7 +106,7 @@ Including a Python API Hook
---------------------------

It is generally the convention in Relay, that functions exported
through ``TVM_REGISTER_API`` should be wrapped in a separate
through ``TVM_REGISTER_GLOBAL`` should be wrapped in a separate
Python function rather than called directly in Python. In the case
of the functions that produce calls to operators, it may be convenient
to bundle them, as in ``python/tvm/relay/op/tensor.py``, where
Expand Down
32 changes: 16 additions & 16 deletions docs/dev/relay_pass_infra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ Python APIs to create a compilation pipeline using pass context.
TVM_DLL static PassContext Create();
TVM_DLL static PassContext Current();
/* Other fields are omitted. */
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();

// Classes to get the Python `with` like syntax.
friend class tvm::With<PassContext>;
};
Expand Down Expand Up @@ -225,7 +225,7 @@ cannot add or delete a function through these passes as they are not aware of
the global information.

.. code:: c++

class FunctionPassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
Expand Down Expand Up @@ -319,7 +319,7 @@ favorably use Python APIs to create a specific pass object.
ModulePass CreateModulePass(std::string name,
int opt_level,
PassFunc pass_func);

SequentialPass CreateSequentialPass(std::string name,
int opt_level,
Array<Pass> passes,
Expand Down Expand Up @@ -347,14 +347,14 @@ registration.
auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});

auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, tvm::Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});

// Create a module for optimization.
auto mod = relay::ModuleNode::FromExpr(fx);

// Create a sequential pass.
tvm::Array<relay::transform::Pass> pass_seqs{
relay::transform::InferType(),
Expand All @@ -363,7 +363,7 @@ registration.
relay::transform::AlterOpLayout()
};
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);

// Create a pass context for the optimization.
auto ctx = relay::transform::PassContext::Create();
ctx->opt_level = 2;
Expand Down Expand Up @@ -421,7 +421,7 @@ Python when needed.
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}

TVM_REGISTER_API("relay._transform.FoldConstant")
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);

} // namespace transform
Expand Down Expand Up @@ -457,10 +457,10 @@ a certain scope.
def __enter__(self):
_transform.EnterPassContext(self)
return self
def __exit__(self, ptype, value, trace):
_transform.ExitPassContext(self)
@staticmethod
def current():
"""Return the current pass context."""
Expand Down Expand Up @@ -580,18 +580,18 @@ using ``Sequential`` associated with other types of passes.
z1 = relay.add(y, c)
z2 = relay.add(z, z1)
func = relay.Function([x], z2)
# Customize the optimization pipeline.
# Customize the optimization pipeline.
seq = _transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
relay.transform.AlterOpLayout()
])
# Create a module to perform optimizations.
mod = relay.Module({"main": func})
# Users can disable any passes that they don't want to execute by providing
# a list, e.g. disabled_pass=["EliminateCommonSubexpr"].
with relay.build_config(opt_level=3):
Expand Down Expand Up @@ -629,7 +629,7 @@ For more pass infra related examples in Python and C++, please refer to

.. _Block: https://mxnet.incubator.apache.org/api/python/docs/api/gluon/block.html#gluon-block

.. _Relay module: https://docs.tvm.ai/langref/relay_expr.html#module-and-global-functions
.. _Relay module: https://docs.tvm.ai/langref/relay_expr.html#module-and-global-functions

.. _include/tvm/relay/transform.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/relay/transform.h

Expand Down
1 change: 0 additions & 1 deletion include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "base.h"
#include "expr.h"
#include "lowered_func.h"
#include "api_registry.h"
#include "runtime/packed_func.h"

namespace tvm {
Expand Down
32 changes: 9 additions & 23 deletions include/tvm/api_registry.h → include/tvm/node/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,19 @@
*/

/*!
* \file tvm/api_registry.h
* \brief This file contains utilities related to
* the TVM's global function registry.
* \file tvm/node/env_func.h
* \brief Serializable global function.
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_
#ifndef TVM_NODE_ENV_FUNC_H_
#define TVM_NODE_ENV_FUNC_H_

#include <tvm/node/reflection.h>

#include <string>
#include <utility>
#include "base.h"
#include "packed_func_ext.h"
#include "runtime/registry.h"

namespace tvm {
/*!
* \brief Register an API function globally.
* It simply redirects to TVM_REGISTER_GLOBAL
*
* \code
* TVM_REGISTER_API(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName)

namespace tvm {
/*!
* \brief Node container of EnvFunc
* \sa EnvFunc
Expand All @@ -54,7 +40,7 @@ class EnvFuncNode : public Object {
/*! \brief Unique name of the global function */
std::string name;
/*! \brief The internal packed function */
PackedFunc func;
runtime::PackedFunc func;
/*! \brief constructor */
EnvFuncNode() {}

Expand Down Expand Up @@ -154,4 +140,4 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef {
};

} // namespace tvm
#endif // TVM_API_REGISTRY_H_
#endif // TVM_NODE_ENV_FUNC_H_
2 changes: 1 addition & 1 deletion include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_BASE_H_
#define TVM_RELAY_BASE_H_

#include <tvm/api_registry.h>

#include <tvm/ir/span.h>
#include <tvm/ir.h>
#include <tvm/node/node.h>
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
#ifndef TVM_RELAY_TYPE_H_
#define TVM_RELAY_TYPE_H_

#include <tvm/api_registry.h>

#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/env_func.h>

#include <tvm/ir.h>
#include <string>

Expand Down
12 changes: 6 additions & 6 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Registry {
*
* \code
*
* TVM_REGISTER_API("addone")
* TVM_REGISTER_GLOBAL("addone")
* .set_body_typed<int(int)>([](int x) { return x + 1; });
*
* \endcode
Expand All @@ -96,7 +96,7 @@ class Registry {
* return x * y;
* }
*
* TVM_REGISTER_API("multiply")
* TVM_REGISTER_GLOBAL("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* \endcode
Expand All @@ -120,7 +120,7 @@ class Registry {
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
Expand Down Expand Up @@ -148,7 +148,7 @@ class Registry {
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
Expand Down Expand Up @@ -181,7 +181,7 @@ class Registry {
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
Expand Down Expand Up @@ -221,7 +221,7 @@ class Registry {
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
Expand Down
32 changes: 17 additions & 15 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,31 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>

#include <tvm/tensor.h>

namespace tvm {
namespace arith {

TVM_REGISTER_API("arith.intset_single_point")
TVM_REGISTER_GLOBAL("arith.intset_single_point")
.set_body_typed(IntSet::single_point);

TVM_REGISTER_API("arith.intset_vector")
TVM_REGISTER_GLOBAL("arith.intset_vector")
.set_body_typed(IntSet::vector);

TVM_REGISTER_API("arith.intset_interval")
TVM_REGISTER_GLOBAL("arith.intset_interval")
.set_body_typed(IntSet::interval);


TVM_REGISTER_API("arith.DetectLinearEquation")
TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);

TVM_REGISTER_API("arith.DetectClipBound")
TVM_REGISTER_GLOBAL("arith.DetectClipBound")
.set_body_typed(DetectClipBound);

TVM_REGISTER_API("arith.DeduceBound")
TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
Expr v, Expr cond,
const Map<Var, IntSet> hint_map,
Expand All @@ -55,36 +57,36 @@ TVM_REGISTER_API("arith.DeduceBound")
});


TVM_REGISTER_API("arith.DomainTouched")
TVM_REGISTER_GLOBAL("arith.DomainTouched")
.set_body_typed(DomainTouched);

TVM_REGISTER_API("_IntervalSetGetMin")
TVM_REGISTER_GLOBAL("_IntervalSetGetMin")
.set_body_method(&IntSet::min);

TVM_REGISTER_API("_IntervalSetGetMax")
TVM_REGISTER_GLOBAL("_IntervalSetGetMax")
.set_body_method(&IntSet::max);

TVM_REGISTER_API("_IntSetIsNothing")
TVM_REGISTER_GLOBAL("_IntSetIsNothing")
.set_body_method(&IntSet::is_nothing);

TVM_REGISTER_API("_IntSetIsEverything")
TVM_REGISTER_GLOBAL("_IntSetIsEverything")
.set_body_method(&IntSet::is_everything);

ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
return ConstIntBound(min_value, max_value);
}

TVM_REGISTER_API("arith._make_ConstIntBound")
TVM_REGISTER_GLOBAL("arith._make_ConstIntBound")
.set_body_typed(MakeConstIntBound);

ModularSet MakeModularSet(int64_t coeff, int64_t base) {
return ModularSet(coeff, base);
}

TVM_REGISTER_API("arith._make_ModularSet")
TVM_REGISTER_GLOBAL("arith._make_ModularSet")
.set_body_typed(MakeModularSet);

TVM_REGISTER_API("arith._CreateAnalyzer")
TVM_REGISTER_GLOBAL("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
Expand Down
Loading

0 comments on commit 3aee8b5

Please sign in to comment.