Commit

[REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager. (#5225)

* [REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager.

- SplitHostDevice
- ThreadSync
- BindDeviceType
- LowerThreadAllreduce
- Provide a temp fix for printing IRModule with PrimFunc before the formal text printer.

* Address comments, fix tests.

* Fix relay tests

* Explicit move
tqchen committed Apr 3, 2020
1 parent 9b274cb commit 75e936e
Showing 28 changed files with 465 additions and 362 deletions.
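At a high level, this commit replaces a chain of per-LoweredFunc ir_pass calls with IRModule-to-IRModule passes composed through the pass manager. The sketch below is a minimal reading of the new mixed-module pipeline from python/tvm/driver/build_module.py in this diff; it assumes an IRModule mod_mixed and a created target object already exist, and omits the detect_global_barrier branch.

    import tvm

    # Mixed-module lowering in the new pass-manager style (sketch; mirrors the
    # opt_mixed list added in this commit, minus the global-barrier variant).
    opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
                 tvm.tir.transform.ThreadSync("shared"),
                 tvm.tir.transform.ThreadSync("warp"),
                 tvm.tir.transform.InferFragment(),
                 tvm.tir.transform.LowerThreadAllreduce(),
                 tvm.tir.transform.BindDeviceType(),
                 tvm.tir.transform.SplitHostDevice()]
    mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)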
14 changes: 7 additions & 7 deletions include/tvm/ir/function.h
@@ -46,20 +46,20 @@ enum class CallingConv : int {
* - Implementation: specified by the native target.
*/
kDefault = 0,
/*!
* \brief PackedFunc that exposes a CPackedFunc signature.
*
* - Calling by PackedFunc calling convention.
* - Implementation: Expose a function with the CPackedFunc signature.
*/
kCPackedFunc = 1,
/*!
* \brief Device kernel launch
*
* - Call by PackedFunc calling convention.
* - Implementation: defined by device runtime(e.g. runtime/cuda)
*/
kDeviceKernelLaunch = 2,
/*!
* \brief PackedFunc that exposes a CPackedFunc signature.
*
* - Calling by PackedFunc calling convention.
* - Implementation: Expose a function with the CPackedFunc signature.
*/
kCPackedFunc = 3,
};

/*!
2 changes: 2 additions & 0 deletions include/tvm/ir/module.h
@@ -324,6 +324,8 @@ class IRModule : public ObjectRef {

/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;
// allow copy on write.
TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
};

/*!
10 changes: 10 additions & 0 deletions include/tvm/tir/analysis.h
@@ -49,6 +49,16 @@ struct ExprDeepEqual {
public:
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};


/*!
* \brief Find undefined vars in the statement.
* \param stmt The statement to be checked.
* \param defs The vars that are defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
78 changes: 0 additions & 78 deletions include/tvm/tir/ir_pass.h
@@ -406,56 +406,6 @@ LoweredFunc MakeAPI(Stmt body,
int num_unpacked_args,
bool is_restricted);

/*!
* \brief Bind the device type of host function to be device_type.
* \param func The function to be bound.
* \param device_type The device type to be bound.
* \return The bound function.
*/
LoweredFunc BindDeviceType(LoweredFunc func,
int device_type);
/*!
* \brief Find undefined vars in the statement.
* \param stmt The statement to be checked.
* \param defs The vars that are defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be split.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param stmt The stmt to be transformed.
* \param storage_scope The storage scope considered.
*/
LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);

/*!
* \brief Lower cross-thread allreduce in the stmt.
* \param f The device function to be lowered.
* \param warp_size the size of warp where no sync is needed.
* \return Transformed function.
*/
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);

/*!
* \brief Lower warp memory in stmt.
* \param f The device function to be lowered.
* \param warp_size the size of warp where no sync is needed.
* this function will only take effect if warp_size is bigger than one.
* \return Transformed function.
*/
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);

/*!
* \brief Remap the thread axis
*
@@ -470,26 +420,6 @@ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
*/
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);

/*!
* \brief Lower packed function call.
* \param f The function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerTVMBuiltin(LoweredFunc f);


/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \note implemented in storage_rewrite.cc
* \param f The function to be transformed
* \return Transformed function.
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
@@ -513,14 +443,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f);
*/
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);

/*!
* \brief Infer the TensorCore fragment information using tensor intrinsics
*
* \param f The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc InferFragment(LoweredFunc f);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
15 changes: 15 additions & 0 deletions include/tvm/tir/transform.h
@@ -58,6 +58,21 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);

/*!
* \brief Bind the device type of the function to be
* the device_type specified in the target attribute.
*
* \return The pass.
*/
TVM_DLL Pass BindDeviceType();

/*!
* \brief Split the function into a host function and device functions.
*
* \return The pass.
*/
TVM_DLL Pass SplitHostDevice();

/*!
* \brief skip assert stmt.
*
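The two new passes declared above are exposed on the Python side as tvm.tir.transform.BindDeviceType() and tvm.tir.transform.SplitHostDevice() (see the driver changes below). A minimal usage sketch, assuming mod is an IRModule whose PrimFuncs already carry a "target" attribute:

    import tvm

    seq = tvm.ir.transform.Sequential([
        tvm.tir.transform.BindDeviceType(),   # bind host code to the device type from the target attribute
        tvm.tir.transform.SplitHostDevice(),  # outline device kernels into separate PrimFuncs
    ])
    mod = seq(mod)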
79 changes: 33 additions & 46 deletions python/tvm/driver/build_module.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name
"""The build utils in python.
This module provides the functions to transform schedule to
@@ -25,6 +27,7 @@

from tvm.runtime import ndarray
from tvm.ir import container
from tvm.ir import CallingConv
from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass
from tvm.tir.stmt import LoweredFunc
@@ -222,75 +225,59 @@ def _build_for_device(flist, target, target_host):
mdev : tvm.module
A module that contains device code.
"""
@tvm.tir.transform.prim_func_pass(opt_level=0)
class BindTarget:
def __init__(self, target):
self.target = target

# pylint: disable=unused-argument
def transform_function(self, func, mod, ctx):
return func.with_attr("target", self.target)

target = _target.create(target)
target_host = _target.create(target_host)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
fdevice = []

for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == LoweredFunc.MixedFunc:
if BuildConfig.current().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = list(ir_pass.SplitHostDevice(func))
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)

if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)

fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
if BuildConfig.current().detect_global_barrier:
opt_mixed += [tvm.tir.transform.ThreadSync("global")]
opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
tvm.tir.transform.ThreadSync("warp"),
tvm.tir.transform.InferFragment(),
tvm.tir.transform.LowerThreadAllreduce(),
tvm.tir.transform.BindDeviceType(),
tvm.tir.transform.SplitHostDevice()]
mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)

if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice

target_host = _target.create(target_host)

# device optimizations
mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
opt_device = tvm.ir.transform.Sequential(
[BindTarget(target),
[tvm.tir.transform.Filter(
lambda f: "calling_conv" in f.attrs and
f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()])
mod_dev = opt_device(mod_dev)
mod_dev = opt_device(mod_mixed)

# host optimizations
mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
opt_host = tvm.ir.transform.Sequential(
[BindTarget(target_host),
[tvm.tir.transform.Filter(
lambda f: "calling_conv" not in f.attrs or
f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall()])
mod_host = opt_host(mod_host)
mod_host = opt_host(mod_mixed)

if device_type == ndarray.cpu(0).device_type and target_host == target:
assert len(mod_dev.functions) == 0
if "gpu" in target.keys and len(mod_dev.functions) == 0:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)

rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
return mod_host, rt_mod_dev


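Written out in the new style, the device-side pipeline from the driver change above looks as follows (a sketch using the names from this commit; mod_mixed is the module produced by the mixed pipeline). The host pipeline mirrors it with the negated calling_conv filter plus LowerTVMBuiltin and CombineContextCall.

    import tvm
    from tvm.ir import CallingConv

    opt_device = tvm.ir.transform.Sequential(
        [tvm.tir.transform.Filter(
            lambda f: "calling_conv" in f.attrs and
            f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
         tvm.tir.transform.LowerWarpMemory(),
         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
         tvm.tir.transform.LowerIntrin()])
    mod_dev = opt_device(mod_mixed)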
2 changes: 1 addition & 1 deletion python/tvm/ir/__init__.py
@@ -23,7 +23,7 @@
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .function import BaseFunc
from .function import CallingConv, BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs, DictAttrs, make_node
8 changes: 8 additions & 0 deletions python/tvm/ir/function.py
@@ -15,10 +15,18 @@
# specific language governing permissions and limitations
# under the License.
"""Function defintiions."""
from enum import IntEnum
from .expr import RelayExpr
from . import _ffi_api


class CallingConv(IntEnum):
"""Possible kinds of calling conventions."""
DEFAULT = 0
C_PACKED_FUNC = 1
DEVICE_KERNEL_LAUNCH = 2


class BaseFunc(RelayExpr):
"""Base class of all functions."""
@property
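The Python CallingConv enum added here mirrors the C++ enum in include/tvm/ir/function.h and is what the driver uses to tell host functions apart from outlined device kernels. A small illustrative check; func is a hypothetical PrimFunc taken from a module that has been through SplitHostDevice:

    from tvm.ir import CallingConv

    is_kernel = (func.attrs is not None and
                 "calling_conv" in func.attrs and
                 func.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH)
    print("device kernel" if is_kernel else "host-side function")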
1 change: 0 additions & 1 deletion python/tvm/ir/module.py
@@ -60,7 +60,6 @@ def __init__(self, functions=None, type_definitions=None):
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)


def __setitem__(self, var, val):
"""Add a mapping to the module.
3 changes: 2 additions & 1 deletion python/tvm/tir/transform/function_pass.py
@@ -16,6 +16,7 @@
# under the License.
"""TIR specific function pass support."""
import inspect
import types
import functools

import tvm._ffi
@@ -142,7 +143,7 @@ def create_function_pass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _ffi_api.MakeFunctionPass(pass_arg, info)
return _ffi_api.CreatePrimFuncPass(pass_arg, info)

if pass_func:
return create_function_pass(pass_func)
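For reference, CreatePrimFuncPass is the FFI entry behind the @tvm.tir.transform.prim_func_pass decorator, which the old driver code used for its BindTarget helper. A sketch of the same idea as a function-style pass; the "llvm" target here is purely illustrative:

    import tvm

    @tvm.tir.transform.prim_func_pass(opt_level=0)
    def bind_example_target(func, mod, ctx):
        # Tag every PrimFunc with a target, like the removed BindTarget helper did.
        return func.with_attr("target", tvm.target.create("llvm"))

    # Usage: new_mod = bind_example_target(mod) for an existing IRModule `mod`.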