Skip to content

Commit

Permalink
[TIR] Refactor MakePackedAPI to target dependent stage. (#5326)
Browse files Browse the repository at this point in the history
Previously MakePackedAPI was in the target independent stage,
but never the less requires the device_type information that will be
binded at a later target dependent stage.

The previous implementation was due to the limitation of LoweredFunc
which can not carry buffer_map info(so they have to be lowered right away).
This is no longer the case after the unified IR refactor.

This PR migrates MakePackedAPI to a target dependent stage
and removes the un-necessary BindDevice pass.
  • Loading branch information
tqchen committed Apr 14, 2020
1 parent 4720cf8 commit f08d5d7
Show file tree
Hide file tree
Showing 21 changed files with 114 additions and 187 deletions.
9 changes: 8 additions & 1 deletion include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,12 +352,19 @@ class Sequential : public Pass {
*
* \return The created module pass.
*/
Pass CreateModulePass(
TVM_DLL Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const Array<runtime::String>& required);


/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \return The pass.
*/
TVM_DLL Pass PrintIR(std::string header);

} // namespace transform
} // namespace tvm

Expand Down
9 changes: 9 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ class TVM_DLL DeviceAPI {
* \return The corresponding device API.
*/
static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false);

/*!
* \brief Whether a certian device type requires set device context
* before launching the kernel function.
* \param device_type The device type.
*/
static bool NeedSetDeviceContext(int device_type) {
return device_type != kDLCPU && device_type != kDLMicroDev;
}
};

/*! \brief The device type bigger than this is RPC device */
Expand Down
9 changes: 0 additions & 9 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,6 @@ TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);
*/
TVM_DLL Pass LowerCustomDatatypes();


/*!
* \brief Bind the device type ofthe function to be
* the device_type specified in the target attribute.
*
* \return The pass.
*/
TVM_DLL Pass BindDeviceType();

/*!
* \brief Split the function into a host function and device functions.
*
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def lower(sch,
if cfg.restricted_func:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return tvm.tir.transform.MakePackedAPI()(mod)
return mod


def _build_for_device(input_mod, target, target_host):
Expand Down Expand Up @@ -243,13 +243,13 @@ def _build_for_device(input_mod, target, target_host):
tvm.tir.transform.ThreadSync("warp"),
tvm.tir.transform.InferFragment(),
tvm.tir.transform.LowerThreadAllreduce(),
tvm.tir.transform.BindDeviceType(),
tvm.tir.transform.MakePackedAPI(),
tvm.tir.transform.SplitHostDevice()]
mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)
mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)


# device optimizations
opt_device = tvm.ir.transform.Sequential(
opt_device = tvm.transform.Sequential(
[tvm.tir.transform.Filter(
lambda f: "calling_conv" in f.attrs and
f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
Expand All @@ -259,7 +259,7 @@ def _build_for_device(input_mod, target, target_host):
mod_dev = opt_device(mod_mixed)

# host optimizations
opt_host = tvm.ir.transform.Sequential(
opt_host = tvm.transform.Sequential(
[tvm.tir.transform.Filter(
lambda f: "calling_conv" not in f.attrs or
f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
Expand Down
27 changes: 21 additions & 6 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@

import tvm._ffi

from tvm._ffi.runtime_ctypes import TVMContext
from tvm.runtime import Object, ndarray as _nd
import tvm.runtime
from tvm.runtime import ndarray as _nd

from . import _ffi_transform_api

@tvm._ffi.register_object("transform.PassInfo")
class PassInfo(Object):
class PassInfo(tvm.runtime.Object):
"""The class contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
This class can be extended by adding new members when more meta data is
Expand All @@ -52,7 +52,7 @@ def __init__(self, opt_level, name, required=None):


@tvm._ffi.register_object("transform.PassContext")
class PassContext(Object):
class PassContext(tvm.runtime.Object):
"""The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used
to help an optimization pass. Such information includes the error reporter
Expand All @@ -79,7 +79,7 @@ def __init__(self,
trace=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
elif isinstance(fallback_device, tvm.runtime.TVMContext):
fallback_device = fallback_device.device_type
if not isinstance(fallback_device, int):
raise TypeError("fallback_device is expected to be the type of " +
Expand Down Expand Up @@ -113,7 +113,7 @@ def current():


@tvm._ffi.register_object("transform.Pass")
class Pass(Object):
class Pass(tvm.runtime.Object):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
conveniently interact with the base class.
Expand Down Expand Up @@ -327,3 +327,18 @@ def create_module_pass(pass_arg):
if pass_func:
return create_module_pass(pass_func)
return create_module_pass


def PrintIR(header):
"""A special trace pass that prints the header and IR.
Parameters
----------
header : str
The header to be displayed along with the dump.
Returns
--------
The pass
"""
return _ffi_transform_api.PrintIR(header)
3 changes: 2 additions & 1 deletion python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,14 @@ def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
mod : IRModule
The created IRModule.
"""
assert num_unpacked_args == 0
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)
return mod


tvm._ffi._init_api("testing", __name__)
42 changes: 15 additions & 27 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def Apply(ftransform):
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
# pylint: disable=unused-argument
Expand All @@ -51,7 +51,7 @@ def Filter(fcond):
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
# pylint: disable=unused-argument
Expand All @@ -67,7 +67,7 @@ def LowerCustomDatatypes():
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerCustomDatatypes()
Expand All @@ -84,30 +84,18 @@ def MakePackedAPI(num_unpacked_params=0):
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.MakePackedAPI(num_unpacked_params)


def BindDeviceType():
"""Bind the device type of the function to be
the device_type specified in the target attribute.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.BindDeviceType()


def SplitHostDevice():
"""Split the function into a host function and device functions.
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.SplitHostDevice()
Expand All @@ -118,7 +106,7 @@ def SkipAssert():
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.SkipAssert()
Expand All @@ -134,7 +122,7 @@ def ThreadSync(storage_scope):
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.ThreadSync(storage_scope)
Expand All @@ -145,7 +133,7 @@ def LowerThreadAllreduce():
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerThreadAllreduce()
Expand All @@ -156,7 +144,7 @@ def InferFragment():
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InferFragment()
Expand All @@ -167,7 +155,7 @@ def LowerWarpMemory():
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerWarpMemory()
Expand All @@ -178,7 +166,7 @@ def LowerTVMBuiltin():
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerTVMBuiltin()
Expand All @@ -189,7 +177,7 @@ def LowerIntrin():
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerIntrin()
Expand All @@ -200,7 +188,7 @@ def LowerDeviceStorageAccessInfo():
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
Note
Expand All @@ -215,7 +203,7 @@ def CombineContextCall():
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.CombineContextCall()
Expand All @@ -231,7 +219,7 @@ def NarrowDataType(target_bits):
Returns
-------
fpass : tvm.ir.transform.Pass
fpass : tvm.transform.Pass
The result pass
Note
Expand Down
5 changes: 2 additions & 3 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ IRModule lower(te::Schedule sch,
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
return tir::transform::MakePackedAPI(0)(mod);
return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
}


Expand All @@ -237,7 +236,7 @@ split_dev_host_funcs(IRModule mod_mixed,
mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
mixed_pass_list.push_back(tir::transform::BindDeviceType());
mixed_pass_list.push_back(tir::transform::MakePackedAPI(0));
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
auto opt_mixed = transform::Sequential(mixed_pass_list);
mod_mixed = opt_mixed(std::move(mod_mixed));
Expand Down
13 changes: 13 additions & 0 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,5 +473,18 @@ TVM_REGISTER_GLOBAL("transform.EnterPassContext")
TVM_REGISTER_GLOBAL("transform.ExitPassContext")
.set_body_typed(PassContext::Internal::ExitScope);


Pass PrintIR(std::string header) {
auto pass_func =[header](IRModule mod, const PassContext& ctx) {
LOG(INFO) << "PrintIR(" << header << "):\n"
<< mod;
return mod;
};
return CreateModulePass(pass_func, 0, "PrintIR", {});
}

TVM_REGISTER_GLOBAL("transform.PrintIR")
.set_body_typed(PrintIR);

} // namespace transform
} // namespace tvm
2 changes: 2 additions & 0 deletions src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ StackVM::StructFieldKind MapFieldKind(int64_t kind) {
}

StackVM CodeGenStackVM::Compile(const PrimFunc& f) {
CHECK_EQ(f->buffer_map.size(), 0U)
<< "Cannot codegen function with buffer_map, please lower them first";
for (size_t i = 0; i < f->params.size(); ++i) {
Var v = f->params[i];
int vid = AllocVarID(v.get());
Expand Down
6 changes: 4 additions & 2 deletions src/tir/analysis/verify_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
/// Check if the value of a Variable comes from function argument.
bool IsFromFunctionArgs(const VarNode *var) const {
const VarNode *V = var;
while (true) {
CHECK(V) << "Invalid Variable\n";
for (auto kv : func_->buffer_map) {
if (V == kv.second->data.get()) return true;
}

while (true) {
// Variable is from function args. Return true.
if (V == func_->params[0].get()) return true;

Expand Down
Loading

0 comments on commit f08d5d7

Please sign in to comment.