Skip to content

Commit

Permalink
[REFACTOR][TIR] Migrate low-level pass functions to Pass Manager, (#5213
Browse files Browse the repository at this point in the history
)

- Migrate LowerTVMBultin
- Migrate inferFragment, LowerThreadAllreduce
- Migrate ThreadSync
- Refactor target::Build to directly take IRModule.
- Remove un-used legacy functions.
  • Loading branch information
tqchen authored Apr 2, 2020
1 parent 88d2f34 commit 44bffdb
Show file tree
Hide file tree
Showing 28 changed files with 407 additions and 329 deletions.
14 changes: 0 additions & 14 deletions include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,6 @@ TVM_DLL Array<tir::LoweredFunc> lower(
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
second is device function array
*/
TVM_DLL Array<Array<tir::LoweredFunc> > split_dev_host_funcs(
const Array<tir::LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
Expand Down
22 changes: 16 additions & 6 deletions include/tvm/target/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_TARGET_CODEGEN_H_

#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/target/target.h>
Expand All @@ -40,16 +41,25 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
* \brief Temporary backward compatible function to convert a list
* of LoweredFunc to a IRModule of PrimfFuncs
* \param funcs The input lowered function.
* \return The IRModule.
*
* \note This function is only used for code refactor and will be
* removed once the refactor completes.
*/
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);

/*!
* \brief Build a module from array of lowered function.
* \param funcs The functions to be built.
* \param mod The Module to be built
* \param target The target to be built.
* \return The builded module.
*
* \note Calls global API function "_codegen_build_" + target
* \return The result runtime::Module.
*/
runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
const std::string& target);
runtime::Module Build(IRModule mod, const Target& target);

/*!
* \brief Pack imported device library to a C file.
* Compile the C file and link with the host library
Expand Down
31 changes: 0 additions & 31 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
*/
LoweredFunc LowerTVMBuiltin(LoweredFunc f);

/*!
* \brief Combine context function calls.
* \param f The host function to be lowered.
* \return Transformed function.
*/
LoweredFunc CombineContextCall(LoweredFunc f);

/*!
* \brief Rewrite the pointer content type of arguments,
Expand All @@ -496,7 +490,6 @@ LoweredFunc CombineContextCall(LoweredFunc f);
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);


/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
Expand All @@ -509,23 +502,6 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
*/
PrimFunc PointerValueTypeRewrite(PrimFunc f);

/*!
* \brief Lower attached storage access information on device.
* Do this pass after all storage access analysis finish.
*
* \param func The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);

/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
* \param target The target device.
* \return Transformed function.
*/
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);

/*!
* \brief Lower custom datatypes.
*
Expand All @@ -545,13 +521,6 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
*/
LoweredFunc InferFragment(LoweredFunc f);

/*!
* \brief skip assert stmt generation
* \param f The function to be transformed.
* \return Transformed function.
*/
LoweredFunc SkipAssert(LoweredFunc f);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
Expand Down
44 changes: 40 additions & 4 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,40 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required);

/*!
* \brief Combine context calls in the host function.
* \brief skip assert stmt.
*
* \return The pass.
*/
TVM_DLL Pass CombineContextCall();
TVM_DLL Pass SkipAssert();

/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param storage_scope The storage scope considered.
* \return The pass.
*/
TVM_DLL Pass ThreadSync(std::string storage_scope);


/*!
* \brief Lower cross thread alleduce.
*
* \return The pass.
*/
TVM_DLL Pass LowerThreadAllreduce();

/*!
* \brief Infer the TensorCore fragment infomation using tensor intrinsics
*
* \return The pass.
*/
TVM_DLL Pass InferFragment();

/*!
* \brief Lower builtin intrinsics.
* \return The pass.
*/
TVM_DLL Pass LowerTVMBuiltin();

/*!
* \brief Lower the target specific function intrinsics in each of the function.
Expand All @@ -72,6 +101,12 @@ TVM_DLL Pass CombineContextCall();
*/
TVM_DLL Pass LowerIntrin();

/*!
* \brief Lower warp memory access to low-level device related function calls.
* \return The pass.
*/
TVM_DLL Pass LowerWarpMemory();

/*!
* \brief Lower attached storage access information on device.
*
Expand All @@ -82,10 +117,11 @@ TVM_DLL Pass LowerIntrin();
TVM_DLL Pass LowerDeviceStorageAccessInfo();

/*!
* \brief Lower warp memory access to low-level device related function calls.
* \brief Combine context calls in the host function.
*
* \return The pass.
*/
TVM_DLL Pass LowerWarpMemory();
TVM_DLL Pass CombineContextCall();


/*!
Expand Down
56 changes: 37 additions & 19 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,15 @@ def _build_for_device(flist, target, target_host):
mdev : tvm.module
A module that contains device code.
"""
@tvm.tir.transform.prim_func_pass(opt_level=0)
class BindTarget:
def __init__(self, target):
self.target = target

# pylint: disable=unused-argument
def transform_function(self, func, mod, ctx):
return func.with_attr("target", self.target)

target = _target.create(target)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
Expand Down Expand Up @@ -250,30 +259,39 @@ def _build_for_device(flist, target, target_host):
else:
raise ValueError("unknown function type %d" % func.func_type)

for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)

if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)

fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]

if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice

target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mdev = codegen.build_module(fdevice, str(target)) if fdevice else None

return fhost, mdev
# device optimizations
mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
opt_device = tvm.ir.transform.Sequential(
[BindTarget(target),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()])
mod_dev = opt_device(mod_dev)

# host optimizations
mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
opt_host = tvm.ir.transform.Sequential(
[BindTarget(target_host),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall()])
mod_host = opt_host(mod_host)

rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
return mod_host, rt_mod_dev


def build(inputs,
Expand Down Expand Up @@ -402,19 +420,19 @@ def build(inputs,
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"

fhost_all = []
mod_host_all = tvm.IRModule({})

device_modules = []
for tar, flist in target_flist.items():
fhost, mdev = _build_for_device(flist, tar, target_host)
# Save the current lowered functions of the host and the device module.
fhost_all += fhost
mod_host, mdev = _build_for_device(flist, tar, target_host)
mod_host_all.update(mod_host)
device_modules.append(mdev)

# Generate a unified host module.
mhost = codegen.build_module(fhost_all, str(target_host))
rt_mod_host = codegen.build_module(mod_host_all, target_host)

# Import all modules.
for mdev in device_modules:
if mdev:
mhost.import_module(mdev)
return mhost
rt_mod_host.import_module(mdev)
return rt_mod_host
12 changes: 7 additions & 5 deletions python/tvm/target/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
# under the License.
"""Code generation related functions."""
from . import _ffi_api
from . import target as _tgt


def build_module(lowered_func, target):
"""Build lowered_func into Module.
def build_module(mod, target):
"""Build IRModule into Module.
Parameters
----------
lowered_func : LoweredFunc
The lowered function
mod : tvm.IRModule
The ir module.
target : str
The target module type.
Expand All @@ -35,7 +36,8 @@ def build_module(lowered_func, target):
module : runtime.Module
The corressponding module.
"""
return _ffi_api.Build(lowered_func, target)
target = _tgt.create(target) if isinstance(target, str) else target
return _ffi_api.Build(mod, target)


def llvm_lookup_intrinsic_id(name):
Expand Down
Loading

0 comments on commit 44bffdb

Please sign in to comment.