Skip to content

Commit

Permalink
[TIR][REFACTOR] Remove ir_pass in favor of analysis/transform. (apach…
Browse files Browse the repository at this point in the history
…e#5415)

This PR removes ir_pass(old style pass functions) in favor
of analysis/transform(new style pass manager).
  • Loading branch information
tqchen authored and trevor-m committed Jun 18, 2020
1 parent 0a57ae6 commit 6eb9e5b
Show file tree
Hide file tree
Showing 99 changed files with 698 additions and 853 deletions.
10 changes: 9 additions & 1 deletion include/tvm/te/schedule_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ TVM_DLL void AutoInlineInjective(Schedule sch);
*/
Map<IterVar, Range> InferBound(const Schedule& sch);

/*!
* \brief Verify if there is any argument bound to compact buffer.
*
* \param stmt The stmt to be verified.
* \return true if there is any buffer_bind_scope attribute found,
* otherwise, false.
*/
bool VerifyCompactBuffer(const Stmt& stmt);

/*!
* \brief Schedule s' dependent operations.
*
Expand All @@ -72,7 +81,6 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);


/*!
* \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
*
Expand Down
105 changes: 101 additions & 4 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
#define TVM_TIR_ANALYSIS_H_

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>

#include <string>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -59,7 +60,47 @@ struct ExprDeepEqual {
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

/*!
* \brief Whether the expression have side effect.
* \param expr The expression to be checked.
* \return whether expression have side effect
*/
TVM_DLL bool HasSideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set..
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
*/
TVM_DLL bool ExprUseVar(const PrimExpr& expr,
std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Whether e expression used var.
* \param expr The expression to be checked.
* \param var The variable.
* \return Whether e uses v.
*/
inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
return ExprUseVar(expr, [&](const VarNode* node) {
return var.get() == node;
});
}


/*!
* \brief Verifies whether the IR stmt or Expr is in SSA form.
* That is: each Var is defined and assigned once(in Let/For)
*
* \param func The function to be verified.
* \return Whether IR is in SSA form.
*
* \note All passes in TIR consume and produce SSA form.
*/
TVM_DLL bool VerifySSA(const PrimFunc& func);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
Expand All @@ -68,11 +109,67 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
* threads, CPU code is generated that tries to access GPU memory,
* which is illegal. This pass performs verification for this case.
*
* \param mod The module to be verified.
* \param func The function to be verified.
* \return Success of memory verification.
*/
void VerifyMemory(const IRModule& mod);
TVM_DLL bool VerifyMemory(const PrimFunc& func);

/*!
* \brief Verify the correctness of a GPU code
* It will check the whether the amount of memory usage or the number of threads
* in a block exceeds the limit
* \param func The function to be checked
* \param constraints The dict to specify constraints to check.
* Possible keys are
*
* "max_local_memory_per_block": Total amount of local memory per block (in bytes).
* "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
* "max_threads_per_block": Maximum number of threads per block.
* "max_thread_x": Maximum length of threadIdx.x.
* "max_thread_y": Maximum length of threadIdx.y.
* "max_thread_z": Maximum length of threadIdx.z.
*
* If one key is missing in this argument, the pass won't check for that item.
* \return valid Whether it is a valid GPU code
*
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func,
Map<std::string, PrimExpr> constraints);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {

using tvm::transform::Pass;
using tvm::transform::PassContext;

/*!
* \brief Pass variant of VerifySSA.
*
* \returns The pass.
* \sa tvm::tir::VerifySSA
*/
TVM_DLL Pass VerifySSA();

/*!
* \brief Pass variant of VerifyMemory.
*
* \returns The pass.
* \sa tvm::tir::VerifyMemory
*/
TVM_DLL Pass VerifyMemory();

/*!
* \brief Pass variant of VerifyGPUCode.
*
* \param constraints The dict to specify constraints to check.
*
* \returns The pass.
* \sa tvm::tir::VerifyGPUCode
*/
TVM_DLL Pass VerifyGPUCode(Map<std::string, PrimExpr> constraints);

} // namespace transform
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
145 changes: 0 additions & 145 deletions include/tvm/tir/ir_pass.h

This file was deleted.

19 changes: 17 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ TVM_DLL Pass InstrumentBoundCheckers();
*/
TVM_DLL Pass MakePackedAPI(int num_unpacked_args);


/*!
* \brief Remap the thread axis
*
Expand All @@ -241,7 +240,6 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
*/
TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);


/*!
* \brief Lower custom datatypes.
*
Expand All @@ -251,6 +249,13 @@ TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);
*/
TVM_DLL Pass LowerCustomDatatypes();

/*!
* \brief Decorate all the function's body as device function.
*
* \return The pass.
*/
TVM_DLL Pass DecorateDeviceScope();

/*!
* \brief Split the function into a host function and device functions.
*
Expand Down Expand Up @@ -334,6 +339,16 @@ TVM_DLL Pass CombineContextCall();
*/
TVM_DLL Pass NarrowDataType(int target_bits);

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \return The pass.
*/
TVM_DLL Pass PointerValueTypeRewrite();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

import tvm._ffi
from tvm import nd, rpc as _rpc, target as _target
from tvm.tir import ir_pass
from tvm.error import TVMError
from tvm.target import build_config
from tvm.driver import build
Expand Down Expand Up @@ -616,7 +615,7 @@ def gpu_verify_pass(**kwargs):
This pass will check memory usage and number of threads per block.
"""
def verify_pass(f, *_):
valid = ir_pass.VerifyGPUCode(f.body, kwargs)
valid = tvm.analysis.verify_gpu_code(f, kwargs)
if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel")
return f
Expand Down
6 changes: 2 additions & 4 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tvm.ir import container
from tvm.ir import CallingConv
from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass
from tvm.te import tensor
from tvm.te import schedule
from tvm import target as _target
Expand Down Expand Up @@ -111,7 +110,7 @@ def form_irmodule(sch, args, name, binds):
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)

compact = ir_pass.VerifyCompactBuffer(stmt)
compact = schedule.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)

stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
Expand Down Expand Up @@ -246,9 +245,8 @@ def _build_for_device(input_mod, target, target_host):

mod_mixed = input_mod
mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
tvm.tir.analysis.verify_memory(mod_mixed)

opt_mixed = []
opt_mixed = [tvm.tir.transform.VerifyMemory()]
if len(mod_mixed.functions) == 1:
opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
if BuildConfig.current().detect_global_barrier:
Expand Down
Loading

0 comments on commit 6eb9e5b

Please sign in to comment.