diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index c06ab50c4423..618fc229d2f3 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -59,6 +59,15 @@ TVM_DLL void AutoInlineInjective(Schedule sch); */ Map InferBound(const Schedule& sch); +/*! + * \brief Verify if there is any argument bound to compact buffer. + * + * \param stmt The stmt to be verified. + * \return true if there is any buffer_bind_scope attribute found, + * otherwise, false. + */ +bool VerifyCompactBuffer(const Stmt& stmt); + /*! * \brief Schedule s' dependent operations. * @@ -72,7 +81,6 @@ Map InferBound(const Schedule& sch); */ Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivial_loop); - /*! * \brief Try to modify the AST generated by ScheduleOps to support TensorCore. * diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 5c4990a9ba92..f7a89f50ef61 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -25,10 +25,11 @@ #define TVM_TIR_ANALYSIS_H_ #include +#include #include #include #include - +#include namespace tvm { namespace tir { @@ -59,7 +60,47 @@ struct ExprDeepEqual { * \param defs The vars that is defined. * \return Array of undefined vars. */ -Array UndefinedVars(const Stmt& stmt, const Array& defs); +TVM_DLL Array UndefinedVars(const Stmt& stmt, const Array& defs); + +/*! + * \brief Whether the expression have side effect. + * \param expr The expression to be checked. + * \return whether expression have side effect + */ +TVM_DLL bool HasSideEffect(const PrimExpr& expr); + +/*! + * \brief Whether e expression used any var in variable set.. + * \param expr The expression to be checked. + * \param vset_contains The check function to see if var is in the vset. + * \return Whether e uses vset. + */ +TVM_DLL bool ExprUseVar(const PrimExpr& expr, + std::function vset_contains); + +/*! + * \brief Whether e expression used var. + * \param expr The expression to be checked. + * \param var The variable. + * \return Whether e uses v. + */ +inline bool ExprUseVar(const PrimExpr& expr, const Var& var) { + return ExprUseVar(expr, [&](const VarNode* node) { + return var.get() == node; + }); +} + + +/*! + * \brief Verifies whether the IR stmt or Expr is in SSA form. + * That is: each Var is defined and assigned once(in Let/For) + * + * \param func The function to be verified. + * \return Whether IR is in SSA form. + * + * \note All passes in TIR consume and produce SSA form. + */ +TVM_DLL bool VerifySSA(const PrimFunc& func); /*! * \brief Verify if memory accesses are legal for a specific target device type. @@ -68,11 +109,67 @@ Array UndefinedVars(const Stmt& stmt, const Array& defs); * threads, CPU code is generated that tries to access GPU memory, * which is illegal. This pass performs verification for this case. * - * \param mod The module to be verified. + * \param func The function to be verified. * \return Success of memory verification. */ -void VerifyMemory(const IRModule& mod); +TVM_DLL bool VerifyMemory(const PrimFunc& func); + +/*! + * \brief Verify the correctness of a GPU code + * It will check the whether the amount of memory usage or the number of threads + * in a block exceeds the limit + * \param func The function to be checked + * \param constraints The dict to specify constraints to check. + * Possible keys are + * + * "max_local_memory_per_block": Total amount of local memory per block (in bytes). + * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). + * "max_threads_per_block": Maximum number of threads per block. + * "max_thread_x": Maximum length of threadIdx.x. + * "max_thread_y": Maximum length of threadIdx.y. + * "max_thread_z": Maximum length of threadIdx.z. + * + * If one key is missing in this argument, the pass won't check for that item. + * \return valid Whether it is a valid GPU code + * + */ +TVM_DLL bool VerifyGPUCode(const PrimFunc& func, + Map constraints); + +// Pass variants of verification analysis +// directly throws RuntimeError when verification fails. +namespace transform { + +using tvm::transform::Pass; +using tvm::transform::PassContext; + +/*! + * \brief Pass variant of VerifySSA. + * + * \returns The pass. + * \sa tvm::tir::VerifySSA + */ +TVM_DLL Pass VerifySSA(); + +/*! + * \brief Pass variant of VerifyMemory. + * + * \returns The pass. + * \sa tvm::tir::VerifyMemory + */ +TVM_DLL Pass VerifyMemory(); + +/*! + * \brief Pass variant of VerifyGPUCode. + * + * \param constraints The dict to specify constraints to check. + * + * \returns The pass. + * \sa tvm::tir::VerifyGPUCode + */ +TVM_DLL Pass VerifyGPUCode(Map constraints); +} // namespace transform } // namespace tir } // namespace tvm #endif // TVM_TIR_ANALYSIS_H_ diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h deleted file mode 100644 index 5dd080b0904f..000000000000 --- a/include/tvm/tir/ir_pass.h +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/tir/ir_pass.h - * \brief Collection of IR pass functions - * - * When the pass functions in this file are for Stmt, - * we can use PassFunction(Evaluate(expr)) to apply it to Expr - */ -#ifndef TVM_TIR_IR_PASS_H_ -#define TVM_TIR_IR_PASS_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include - - -namespace tvm { -namespace tir { - - -/*! - * \brief verifies whether the IR stmt or Expr is in SSA form. - * That is: each VarExpr is defined and assigned once(in Let/For) - * - * \param ir The root of the IR DAG. - * \return Whether IR is in SSA form. - * \note All the passes in this file uses SSA form and outputs SSA form. - */ -TVM_DLL bool VerifySSA(const Stmt& ir); - -/*! - * \brief Whether the expression have side effect. - * \return whether expression have side effect - */ -TVM_DLL bool HasSideEffect(const PrimExpr& e); - -/*! - * \brief Whether e expression used var. - * \param e The expression to be checked. - * \param v The variable. - * \return Whether e uses v. - */ -bool ExprUseVar(const PrimExpr& e, const Var& v); - -/*! - * \brief Whether e expression used any var in variable set.. - * \param e The expression to be checked. - * \param vset The variable set. - * \return Whether e uses vset. - */ -bool ExprUseVar(const PrimExpr& e, const std::unordered_set& vset); - -/*! - * \brief Convert a IR node to be SSA form. - * \param stmt The source statement to be converted. - * \return The converted form. - */ -TVM_DLL Stmt ConvertSSA(Stmt stmt); - -/*! - * \brief Verify if there is any argument bound to compact buffer. - * - * \param stmt The stmt to be verified. - * \return true if there is any buffer_bind_scope attribute found, - * otherwise, false. - */ -bool VerifyCompactBuffer(Stmt stmt); - -/*! - * \brief Decorate the stmt with a device scope, this is helpful for - * hardware accelerator without thread blocks. - * - * \param stmt The stmt to be transformed - * \return Transformed stmt. - */ -Stmt DecorateDeviceScope(Stmt stmt); - -/*! - * \brief Loop invariant code motion which locates and hoists if statements. - * \param stmt The stmt to do if statement hoisting. - * \return Transformed stmt. - */ -Stmt HoistIfThenElse(Stmt stmt); - -/*! - * \brief Rewrite the pointer content type of arguments, - * as well as Alloc internal to the function to use - * the most frequently accessed type for load/store - * to avoid pointer casting in backend when possible. - * - * \note implemeneted in storage_rewrite.cc - * \param f The function to be trasnformed - * \return Transformed function. - */ -PrimFunc PointerValueTypeRewrite(PrimFunc f); - -/*! - * \brief Verify the correctness of a GPU code - * It will check the whether the amount of memory usage or the number of threads - * in a block exceeds the limit - * \param stmt The statement to be checked - * \param constraints The dict to specify constraints to check. - * Possible keys are - * - * "max_local_memory_per_block": Total amount of local memory per block (in bytes). - * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). - * "max_threads_per_block": Maximum number of threads per block. - * "max_thread_x": Maximum length of threadIdx.x. - * "max_thread_y": Maximum length of threadIdx.y. - * "max_thread_z": Maximum length of threadIdx.z. - * - * If one key is missing in this argument, the pass won't check for that item. - * \return valid Whether it is a valid GPU code - * - */ -bool VerifyGPUCode(Stmt stmt, - Map constraints); - -} // namespace tir -} // namespace tvm -#endif // TVM_TIR_IR_PASS_H_ diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 09ea09731f51..abf8b1ce9965 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -228,7 +228,6 @@ TVM_DLL Pass InstrumentBoundCheckers(); */ TVM_DLL Pass MakePackedAPI(int num_unpacked_args); - /*! * \brief Remap the thread axis * @@ -241,7 +240,6 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args); */ TVM_DLL Pass RemapThreadAxis(Map axis_map); - /*! * \brief Lower custom datatypes. * @@ -251,6 +249,13 @@ TVM_DLL Pass RemapThreadAxis(Map axis_map); */ TVM_DLL Pass LowerCustomDatatypes(); +/*! + * \brief Decorate all the function's body as device function. + * + * \return The pass. + */ +TVM_DLL Pass DecorateDeviceScope(); + /*! * \brief Split the function into a host function and device functions. * @@ -334,6 +339,16 @@ TVM_DLL Pass CombineContextCall(); */ TVM_DLL Pass NarrowDataType(int target_bits); +/*! + * \brief Rewrite the pointer content type of arguments, + * as well as Alloc internal to the function to use + * the most frequently accessed type for load/store + * to avoid pointer casting in backend when possible. + * + * \return The pass. + */ +TVM_DLL Pass PointerValueTypeRewrite(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 5ddc5df7d1f5..bddb42034a50 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -35,7 +35,6 @@ import tvm._ffi from tvm import nd, rpc as _rpc, target as _target -from tvm.tir import ir_pass from tvm.error import TVMError from tvm.target import build_config from tvm.driver import build @@ -616,7 +615,7 @@ def gpu_verify_pass(**kwargs): This pass will check memory usage and number of threads per block. """ def verify_pass(f, *_): - valid = ir_pass.VerifyGPUCode(f.body, kwargs) + valid = tvm.analysis.verify_gpu_code(f, kwargs) if not valid: raise InstantiationError("Skipped because of invalid gpu kernel") return f diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index dcd6d444f02d..216cad992d98 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -26,7 +26,6 @@ from tvm.ir import container from tvm.ir import CallingConv from tvm.target import codegen, BuildConfig -from tvm.tir import ir_pass from tvm.te import tensor from tvm.te import schedule from tvm import target as _target @@ -111,7 +110,7 @@ def form_irmodule(sch, args, name, binds): bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) - compact = ir_pass.VerifyCompactBuffer(stmt) + compact = schedule.VerifyCompactBuffer(stmt) binds, arg_list = get_binds(args, compact, binds) stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds) @@ -246,9 +245,8 @@ def _build_for_device(input_mod, target, target_host): mod_mixed = input_mod mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) - tvm.tir.analysis.verify_memory(mod_mixed) - opt_mixed = [] + opt_mixed = [tvm.tir.transform.VerifyMemory()] if len(mod_mixed.functions) == 1: opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))] if BuildConfig.current().detect_global_barrier: diff --git a/python/tvm/runtime/packed_func.py b/python/tvm/runtime/packed_func.py index a04e32be0ea2..af4265a66ad1 100644 --- a/python/tvm/runtime/packed_func.py +++ b/python/tvm/runtime/packed_func.py @@ -44,8 +44,6 @@ class PackedFunc(PackedFuncBase): The compiled module returns Function. TVM backend also registers and exposes its API as Functions. - For example, the developer function exposed in tvm.ir_pass are actually - C++ functions that are registered as PackedFunc The following are list of common usage scenario of tvm.runtime.PackedFunc. diff --git a/python/tvm/target/build_config.py b/python/tvm/target/build_config.py index 8aae6be54a8b..538ee7d5f544 100644 --- a/python/tvm/target/build_config.py +++ b/python/tvm/target/build_config.py @@ -20,83 +20,9 @@ import tvm.ir from tvm.runtime import Object -from tvm.tir import Stmt from . import _ffi_api -class DumpIR(object): - """ - Dump IR for each pass. - With it, you can dump ir just like gcc/llvm. - - How to use: - ----------- - .. code-block:: python - - with tvm.target.build_config(dump_pass_ir=True) - run() - """ - scope_level = 0 - def __init__(self): - self._pass_id = 0 - self._recover_list = [] - - def decorate(self, func): - """ decorate the pass function""" - def dump(*args, **kwargs): - """dump function""" - retv = func(*args, **kwargs) - if not isinstance(retv, (Stmt,)): - return retv - fname = func.func_name if hasattr(func, 'func_name') else func.__name__ - pname = str(self._pass_id) + "_" + fname + "_ir.cc" - with open(pname, "a") as f: - out = retv - f.write(str(out)) - self._pass_id += 1 - return retv - return dump - - def decorate_irpass(self): - """decorate ir_pass and ScheduleOps""" - self._old_sgpass = tvm.te.schedule.ScheduleOps - tvm.te.schedule.ScheduleOps = self.decorate(tvm.te.schedule.ScheduleOps) - vset = vars(tvm.tir.ir_pass) - k = v = 0 - def recover(): - vset[k] = v - for k, v in vset.items(): - self._recover_list.append(recover) - vset[k] = self.decorate(v) if isinstance(v, tvm.runtime.PackedFunc) else v - - def decorate_custompass(self, custom_pass): - """decorate given list of custom passes, and return decorated passes""" - custom_pass = custom_pass if custom_pass else [] - pass_list = [] - for idx, x in enumerate(custom_pass): - x[1].__name__ = "custom{}_phase{}".format(idx, x[0]) - pass_list += [(x[0], self.decorate(x[1]))] - return pass_list - - def enter(self): - """only decorate outermost nest""" - if DumpIR.scope_level > 0: - return - self.decorate_irpass() - self._pass_id = 0 - DumpIR.scope_level += 1 - - def exit(self): - """recover outermost nest""" - if DumpIR.scope_level > 1: - return - # recover decorated functions - for f in self._recover_list: - f() - tvm.te.schedule.ScheduleOps = self._old_sgpass - DumpIR.scope_level -= 1 - - @tvm._ffi.register_object class BuildConfig(Object): """Configuration scope to set a build config option. @@ -129,7 +55,6 @@ class BuildConfig(Object): "disable_vectorize": False, "disable_assert": False } - _dump_ir = DumpIR() # pylint: disable=no-member def __init__(self, handle): @@ -163,13 +88,9 @@ def add_lower_pass(self, value): def __enter__(self): # pylint: disable=protected-access _ffi_api.EnterBuildConfigScope(self) - if self.dump_pass_ir: - BuildConfig._dump_ir.enter() return self def __exit__(self, ptype, value, trace): - if self.dump_pass_ir: - BuildConfig._dump_ir.exit() _ffi_api.ExitBuildConfigScope(self) def __setattr__(self, name, value): diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 6a62505a3034..7d06eea9632e 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -45,7 +45,6 @@ from .op import comm_reducer, min, max, sum from . import ir_builder -from . import ir_pass from . import transform from . import analysis from . import stmt_functor diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 448d0e6c5f8e..1a3eb4806677 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -57,12 +57,52 @@ def expr_deep_equal(lhs, rhs): return _ffi_api.expr_deep_equal(lhs, rhs) -def verify_memory(mod): +def verify_ssa(func): + """Verify if the func is in SSA form. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The module to be verified. + + Returns + ------- + result : bool + The result of verification. + """ + return _ffi_api.verify_ssa(func) + + +def verify_memory(func): + """Verify if func contains illegal host side direct memory access. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The module to be verified. + + Returns + ------- + result : bool + The result of verification. + """ + return _ffi_api.verify_memory(func) + + +def verify_gpu_code(func, constraints): """Verify if module contains illegal host side direct memory access. Parameters ---------- - mod: tvm.IRModule + func: tvm.tir.PrimFunc The module to be verified. + + constraints : Dict[str, int] + The attribute constraints. + + Returns + ------- + result : bool + The result of verification. """ - _ffi_api.verify_memory(mod) + return _ffi_api.verify_gpu_code(func, constraints) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index bb39c1f69131..6d797f8772ec 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -330,6 +330,17 @@ def SplitHostDevice(): return _ffi_api.SplitHostDevice() +def DecorateDeviceScope(): + """Decorate all the function's body as device function. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.DecorateDeviceScope() + + def SkipAssert(): """Skip assert stmt. @@ -456,3 +467,14 @@ def NarrowDataType(target_bits): Run this pass after StorageFlatten. """ return _ffi_api.NarrowDataType(target_bits) + + +def VerifyMemory(): + """Verify if func contains illegal host side direct memory access. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.VerifyMemory() diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index 26be5d51115f..eeaaa8af0ae5 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 58723170b3ca..c7f90f535e29 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -23,7 +23,8 @@ */ #include #include -#include +#include +#include #include #include #include @@ -156,10 +157,14 @@ Array DetectLinearEquation(const PrimExpr& e, } std::unordered_set vset; + auto vset_contains = [&](const VarNode* node) { + return vset.count(node) != 0; + }; + for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset)) { + if (ExprUseVar(coeff[i - 2], vset_contains)) { return Array(); } } diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 2467e758c3e4..81443db38d47 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -22,7 +22,6 @@ * \brief Utility to deduce bound of expression */ #include -#include #include #include #include diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 6e653cec3c3b..0ae98412b017 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -20,7 +20,7 @@ /*! * \file tvm/arith/ir_mutator_with_analyzer.cc */ -#include +#include #include #include "ir_mutator_with_analyzer.h" diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 394e5db9c93e..f6004e2ad9b9 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -49,11 +49,11 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { using StmtExprMutator::VisitExpr_; // override functions that need to populate the context information. - Stmt VisitStmt_(const tir::ForNode* op) override; - Stmt VisitStmt_(const tir::LetStmtNode* op) override; - Stmt VisitStmt_(const tir::IfThenElseNode* op) override; - Stmt VisitStmt_(const tir::AttrStmtNode* op) override; - Stmt VisitStmt_(const tir::AssertStmtNode* op) override; + tir::Stmt VisitStmt_(const tir::ForNode* op) override; + tir::Stmt VisitStmt_(const tir::LetStmtNode* op) override; + tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override; + tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override; + tir::Stmt VisitStmt_(const tir::AssertStmtNode* op) override; PrimExpr VisitExpr_(const tir::LetNode* op) override; PrimExpr VisitExpr_(const tir::SelectNode* op) override; PrimExpr VisitExpr_(const tir::CallNode* op) override; diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 0920ed3d0712..14cfbd6f57f9 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -65,8 +65,8 @@ #ifndef TVM_ARITH_PATTERN_MATCH_H_ #define TVM_ARITH_PATTERN_MATCH_H_ -#include #include +#include #include #include "const_fold.h" @@ -149,9 +149,9 @@ class PEqualChecker { }; template<> -class PEqualChecker { +class PEqualChecker { public: - bool operator()(const Var& lhs, const Var& rhs) const { + bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { return lhs.same_as(rhs); } }; diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7e2ef701265c..849c74028d63 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -27,7 +27,6 @@ #include #include -#include #include #include #include @@ -142,7 +141,7 @@ IRModule lower(te::Schedule sch, // Before TIR transformation. auto bounds = te::InferBound(sch); auto stmt = te::ScheduleOps(sch, bounds, false); - bool compact = tir::VerifyCompactBuffer(stmt); + bool compact = te::VerifyCompactBuffer(stmt); Map out_binds; GetBinds(args, compact, binds, &out_binds, &out_arg_list, config); @@ -196,10 +195,10 @@ split_dev_host_funcs(IRModule mod_mixed, const Target& target, const Target& target_host, const BuildConfig& config) { - mod_mixed = BindTarget(target)(std::move(mod_mixed)); - tir::VerifyMemory(mod_mixed); - - Array mixed_pass_list = {BindTarget(target)}; + Array mixed_pass_list = { + BindTarget(target), + tir::transform::VerifyMemory() + }; if (config->detect_global_barrier) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index e2d5e93fa5c1..c26228ebeba3 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -28,7 +28,6 @@ #include #include #include -#include #include #include "../../target/source/codegen_source_base.h" diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 65e6ae9e79c6..5c98728b1c25 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -32,7 +32,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 66dab57fd947..b3e1772e904f 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -22,7 +22,6 @@ * \brief Convolution operators */ #include -#include #include #include #include diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 6c5aebe2bd4c..f33cd7e4d7c7 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -24,7 +24,6 @@ #ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_ #define TVM_RELAY_OP_NN_CONVOLUTION_H_ -#include #include #include diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 0eceea81da17..262abf85a20c 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -25,7 +25,6 @@ #include #include -#include #include #include diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index dde842765c78..e474b9cf6bcf 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -23,7 +23,6 @@ #ifdef TVM_LLVM_VERSION #include -#include #include #include #include diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 5c7ca6fb622f..e851f37901a2 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -45,7 +45,7 @@ #include "llvm_common.h" #include "../../runtime/thread_storage_scope.h" #include "../../arith/compute_expr.h" -#include "../../tir/pass/ir_util.h" +#include "../../tir/transforms/ir_util.h" namespace tvm { namespace codegen { diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 161c1ca3bab1..0f73e1b59060 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -24,7 +24,7 @@ // Use libspirv for parsing and validating code. #include #include -#include +#include #include "codegen_spirv.h" #include "../build_common.h" @@ -80,6 +80,8 @@ runtime::Module BuildSPIRV(IRModule mod) { const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); + mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + CodeGenSPIRV cg; for (auto kv : mod->functions) { @@ -94,7 +96,7 @@ runtime::Module BuildSPIRV(IRModule mod) { << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); - f = PointerValueTypeRewrite(std::move(f)); + VulkanShader shader; shader.data = cg.BuildFunction(f); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 5d05b0811db4..be058b7306eb 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -22,7 +22,6 @@ * \brief Generate SPIRV block */ #include -#include #include #include #include "codegen_spirv.h" diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index f50760711dec..b51e8edf027a 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index 9ebb89ee95a6..d5b6fec3698a 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -27,7 +27,6 @@ #include #include #include -#include #include #include "ad_util.h" diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 6f703c9ec4e3..2d9f13baedac 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -4,7 +4,7 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance +5B * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include #include @@ -630,15 +630,20 @@ Stmt TransformUpdate(const Stage& stage, banned.insert(iv->var.get()); } } + + auto fbanned = [&](const VarNode* node) { + return banned.count(node); + }; + for (const PrimExpr& pred : n.main_predicates) { - if (tir::ExprUseVar(pred, banned)) { + if (tir::ExprUseVar(pred, fbanned)) { LOG(FATAL) << "Tensorize update transform failed, the condition " << pred << " has a conflict with the reset condition"; } } return IfThenElseNode::make(arith::ComputeReduce(conds, const_true(1)), - update, body); + update, body); } } // namespace te diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 1b3d87d57006..1ec17e9e38a9 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -21,7 +21,6 @@ * \brief Logics related to cross thread reduction, used by ComputeOpNode. * \file cross_thread_reduction.cc */ -#include #include "compute_op.h" #include "op_util.h" diff --git a/src/te/operation/hybrid_op.h b/src/te/operation/hybrid_op.h index a7b2cb16c080..dadfecd3bdef 100644 --- a/src/te/operation/hybrid_op.h +++ b/src/te/operation/hybrid_op.h @@ -32,8 +32,8 @@ #include #include "../schedule/message_passing.h" -#include "../../tir/pass/ir_util.h" -#include "../../tir/pass/arg_binder.h" +#include "../../tir/transforms/ir_util.h" +#include "../../tir/transforms/arg_binder.h" namespace tvm { namespace te { diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index d022134065c6..f7e0e51fd16a 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -22,7 +22,6 @@ * \file op_util.cc */ #include -#include #include #include #include diff --git a/src/te/operation/op_util.h b/src/te/operation/op_util.h index fbe2e0a95d79..f95f84ac4d01 100644 --- a/src/te/operation/op_util.h +++ b/src/te/operation/op_util.h @@ -29,8 +29,8 @@ #include #include #include -#include "../../tir/pass/ir_util.h" -#include "../../tir/pass/arg_binder.h" +#include "../../tir/transforms/ir_util.h" +#include "../../tir/transforms/arg_binder.h" #include "../schedule/message_passing.h" namespace tvm { diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 1916b4a4823e..49929282efb3 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include "op_util.h" #include "../schedule/graph.h" diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index b66406969c76..31d4b368ad89 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include @@ -144,14 +143,19 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, } } } + + auto fbanned = [&](const VarNode* node) { + return banned.count(node); + }; + for (const PrimExpr& pred : n.main_predicates) { - if (tir::ExprUseVar(pred, banned)) { + if (tir::ExprUseVar(pred, fbanned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; } } for (const PrimExpr& pred : n.init_predicates) { - if (tir::ExprUseVar(pred, banned)) { + if (tir::ExprUseVar(pred, fbanned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; } diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 50cbafd2b654..4dde945baf8c 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include "graph.h" diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 6ed9438ec90f..1453ed0683e4 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include "message_passing.h" #include "../../arith/compute_expr.h" diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc index 778899c36333..c3f333e522c8 100644 --- a/src/te/schedule/operation_inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -22,10 +22,12 @@ */ #include #include -#include +#include #include #include #include "operation_inline.h" +#include "../../tir/transforms/ir_util.h" + namespace tvm { namespace te { diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 48a27d17b700..f3e76a45e7db 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -23,12 +23,11 @@ #include #include #include -#include #include #include "message_passing.h" #include "operation_inline.h" -#include "../../tir/pass/ir_util.h" +#include "../../tir/transforms/ir_util.h" #include "../../arith/compute_expr.h" namespace tvm { diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index c818218fa65e..bdb77b6ba472 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -22,7 +22,7 @@ */ #include #include -#include +#include #include #include #include @@ -31,7 +31,7 @@ #include #include "graph.h" #include "../operation/op_util.h" -#include "../../tir/pass/ir_util.h" +#include "../../tir/transforms/ir_util.h" namespace tvm { namespace te { diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 797a8b2b7b88..2198827279ac 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include diff --git a/src/tir/pass/verify_compact_buffer.cc b/src/te/schedule/verify_compact_buffer.cc similarity index 83% rename from src/tir/pass/verify_compact_buffer.cc rename to src/te/schedule/verify_compact_buffer.cc index 5328165ffb91..759adb9e76fc 100644 --- a/src/tir/pass/verify_compact_buffer.cc +++ b/src/te/schedule/verify_compact_buffer.cc @@ -21,16 +21,18 @@ * \file verify_compact_buffer.cc * \brief Verify if there was any compact buffer bound to a statement. */ +#include #include #include #include #include #include +#include #include namespace tvm { -namespace tir { +namespace te { class VerifyBuffer : public StmtVisitor { public: @@ -41,7 +43,7 @@ class VerifyBuffer : public StmtVisitor { void VisitStmt_(const AttrStmtNode* op) final { StmtVisitor::VisitStmt_(op); - if (op->attr_key == attr::buffer_bind_scope) { + if (op->attr_key == tir::attr::buffer_bind_scope) { is_compact_ = true; } } @@ -50,10 +52,13 @@ class VerifyBuffer : public StmtVisitor { bool is_compact_{false}; }; -bool VerifyCompactBuffer(Stmt stmt) { +bool VerifyCompactBuffer(const Stmt& stmt) { VerifyBuffer verifier; return verifier.Verify(stmt); } -} // namespace tir +TVM_REGISTER_GLOBAL("schedule.VerifyCompactBuffer") +.set_body_typed(VerifyCompactBuffer); + +} // namespace te } // namespace tvm diff --git a/src/tir/pass/detect_device.cc b/src/tir/analysis/side_effect.cc similarity index 59% rename from src/tir/pass/detect_device.cc rename to src/tir/analysis/side_effect.cc index ee3a2e23b487..10039d9c1f11 100644 --- a/src/tir/pass/detect_device.cc +++ b/src/tir/analysis/side_effect.cc @@ -18,20 +18,38 @@ */ /*! - * \file detect_device.cc + * \file side_effect.cc + * \brief side effect analysis */ - -#include -#include "ir_util.h" +#include +#include +#include namespace tvm { namespace tir { -Stmt DecorateDeviceScope(Stmt stmt) { - Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), - tir::attr::device_scope, - 0, - stmt); - return body; + +class ExprSideEffect : public ExprVisitor { + public: + void VisitExpr(const PrimExpr& e) final { + if (has_side_effect_) return; + ExprVisitor::VisitExpr(e); + } + + void VisitExpr_(const CallNode* op) final { + if (!op->is_pure()) { + has_side_effect_ = true; return; + } else { + ExprVisitor::VisitExpr_(op); + } + } + + bool has_side_effect_{false}; +}; + +bool HasSideEffect(const PrimExpr& e) { + ExprSideEffect v; + v(e); + return v.has_side_effect_; } } // namespace tir diff --git a/src/tir/pass/simple_passes.cc b/src/tir/analysis/var_touch.cc similarity index 51% rename from src/tir/pass/simple_passes.cc rename to src/tir/analysis/var_touch.cc index a7dc12f8a87f..ffc7792a15c7 100644 --- a/src/tir/pass/simple_passes.cc +++ b/src/tir/analysis/var_touch.cc @@ -18,44 +18,22 @@ */ /*! - * \file simple_passes.cc + * \file simple_analysis.cc * \brief Implementation of simple passes */ #include #include -#include +#include namespace tvm { namespace tir { -class IRSideEffect : public ExprVisitor { - public: - void VisitExpr(const PrimExpr& e) final { - if (has_side_effect_) return; - ExprVisitor::VisitExpr(e); - } - - void VisitExpr_(const CallNode* op) final { - if (!op->is_pure()) { - has_side_effect_ = true; return; - } else { - ExprVisitor::VisitExpr_(op); - } - } - - bool has_side_effect_{false}; -}; - -bool HasSideEffect(const PrimExpr& e) { - IRSideEffect v; - v(e); - return v.has_side_effect_; -} - - - class VarTouchVisitor : public ExprVisitor { public: + explicit VarTouchVisitor( + std::function var_set) + : var_set_(var_set) {} + void VisitExpr(const PrimExpr& e) final { if (use_var_) return; ExprVisitor::VisitExpr(e); @@ -70,45 +48,20 @@ class VarTouchVisitor : public ExprVisitor { ExprVisitor::VisitExpr_(op); } - virtual void Handle(const VarNode* var) = 0; - - bool use_var_{false}; -}; - -class ExprUseVarVisitor : public VarTouchVisitor { - public: - explicit ExprUseVarVisitor(const VarNode* var) - : var_(var) {} - - void Handle(const VarNode* var) final { - if (var == var_) use_var_ = true; + void Handle(const VarNode* var) { + if (var_set_(var)) use_var_ = true; } - private: - const VarNode* var_; -}; -class ExprUseVSetVisitor : public VarTouchVisitor { - public: - explicit ExprUseVSetVisitor( - const std::unordered_set& vset) - : vset_(vset) {} + bool use_var_{false}; - void Handle(const VarNode* var) final { - if (vset_.count(var)) use_var_ = true; - } private: - const std::unordered_set& vset_; + std::function var_set_; }; -bool ExprUseVar(const PrimExpr& e, const Var& v) { - ExprUseVarVisitor visitor(v.get()); - visitor(e); - return visitor.use_var_; -} bool ExprUseVar(const PrimExpr& e, - const std::unordered_set& vset) { - ExprUseVSetVisitor visitor(vset); + std::function var_set) { + VarTouchVisitor visitor(var_set); visitor(e); return visitor.use_var_; } diff --git a/src/tir/pass/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc similarity index 89% rename from src/tir/pass/verify_gpu_code.cc rename to src/tir/analysis/verify_gpu_code.cc index 70d909a859cc..3dd15002ea73 100644 --- a/src/tir/pass/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -25,7 +25,7 @@ */ #include - +#include #include #include #include @@ -164,7 +164,7 @@ class GPUCodeVerifier : public StmtVisitor { } }; -bool VerifyGPUCode(Stmt stmt, +bool VerifyGPUCode(const PrimFunc& func, Map constraints) { GPUCodeVerifier verifier; @@ -193,7 +193,7 @@ bool VerifyGPUCode(Stmt stmt, LOG(FATAL) << "Invalid check item: " << iter.first; } - return verifier.Verify(stmt, + return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, max_threads_per_block, @@ -202,5 +202,30 @@ bool VerifyGPUCode(Stmt stmt, max_thread_z); } + +TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code") +.set_body_typed(VerifyGPUCode); + +namespace transform { + +Pass VerifyGPUCode(Map constraints) { + auto pass_func = [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + auto func = GetRef(n); + CHECK(VerifyGPUCode(func, constraints)) + << "RuntimeError: GPU constraint violated" + << func; + } + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode") +.set_body_typed(VerifyGPUCode); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 8e684e966770..03a36066bc08 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -22,6 +22,7 @@ * \brief Pass to check if memory accesses are legal. */ #include +#include #include #include #include @@ -174,31 +175,46 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } // namespace /// Interface of VerifyMemory pass -void VerifyMemory(const IRModule& mod) { - for (auto kv : mod->functions) { - if (auto* n = kv.second.as()) { - PrimFunc func = GetRef(n); - auto target = func->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerWarpMemory: Require the target attribute"; - - if (func->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDefault) { - MemoryAccessVerifier v(func, target.value()->device_type); - v.Run(); - if (v.Failed()) { - LOG(FATAL) - << "ValueError: Direct host side access to device memory is detected." +bool VerifyMemory(const PrimFunc& func) { + auto target = func->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) + << "LowerWarpMemory: Require the target attribute"; + + if (func->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDefault) { + MemoryAccessVerifier v(func, target.value()->device_type); + v.Run(); + return !v.Failed(); + } else { + return true; + } +} + +TVM_REGISTER_GLOBAL("tir.analysis.verify_memory") +.set_body_typed(VerifyMemory); + +namespace transform { + +Pass VerifyMemory() { + auto pass_func = [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + auto func = GetRef(n); + CHECK(VerifyMemory(func)) + << "RuntimeError: Direct host side access to device memory is detected." << " Did you forget to bind?\n" << func; - } } } - } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); } -TVM_REGISTER_GLOBAL("tir.analysis.verify_memory") +TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory") .set_body_typed(VerifyMemory); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc new file mode 100644 index 000000000000..97eaf2437523 --- /dev/null +++ b/src/tir/analysis/verify_ssa.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * SSA related checks and pass. + * + * SSA requires each varaible to be only defined once. + * \file verify_ssa.cc + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class IRVerifySSA final : public StmtExprVisitor { + public: + bool is_ssa{true}; + + void VisitExpr(const PrimExpr& n) final { + if (!is_ssa) return; + StmtExprVisitor::VisitExpr(n); + } + void VisitStmt(const Stmt& n) final { + if (!is_ssa) return; + StmtExprVisitor::VisitStmt(n); + } + void VisitExpr_(const LetNode* op) final { + MarkDef(op->var.get()); + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const LetStmtNode* op) final { + MarkDef(op->var.get()); + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const ForNode* op) final { + MarkDef(op->loop_var.get()); + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const AllocateNode* op) final { + MarkDef(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const VarNode* node) final { + if (match_scope_) { + MarkDef(node, true); + } + } + + void Run(const PrimFunc& func) { + for (auto param : func->params) { + MarkDef(param.get()); + } + + for (auto kv : func->buffer_map) { + this->DefineBuffer(kv.second); + } + this->VisitStmt(func->body); + } + + void DefineBuffer(const Buffer& buffer) { + match_scope_ = true; + this->VisitExpr(buffer->data); + for (size_t i = 0; i < buffer->shape.size(); ++i) { + this->VisitExpr(buffer->shape[i]); + } + + if (buffer->strides.defined()) { + for (size_t i = 0; i < buffer->strides.size(); ++i) { + this->VisitExpr(buffer->strides[i]); + } + } + this->VisitExpr(buffer->elem_offset); + + match_scope_ = false; + } + + private: + void MarkDef(const VarNode* v, bool allow_dup = false) { + if (defined_.count(v) != 0) { + if (!allow_dup) { + is_ssa = false; return; + } + } else { + defined_[v] = 1; + } + } + // whether we are in match scope, where a var can occur multiple times. + bool match_scope_{false}; + std::unordered_map defined_; +}; + + +bool VerifySSA(const PrimFunc& func) { + IRVerifySSA visitor; + visitor.Run(func); + return visitor.is_ssa; +} + +TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa") +.set_body_typed(VerifySSA); + + +namespace transform { + +Pass VerifySSA() { + auto pass_func = [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + auto func = GetRef(n); + CHECK(VerifySSA(func)) + << "RuntimeError: IR is not in SSA form" + << func; + } + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.VerifySSA") +.set_body_typed(VerifySSA); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index a7bc822fdd30..0f1c572fd0a4 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include #include diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index e8c850a4831a..a36d81f8f306 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -27,7 +27,7 @@ #include #include #include -#include "../pass/ir_util.h" + #include "../../support/str_escape.h" namespace tvm { @@ -362,9 +362,11 @@ Array CommReducerNode::operator()(Array a, Array b value_map.Set(lhs[i], a[i]); value_map.Set(rhs[i], b[i]); } - return UpdateArray(result, [&value_map] (const PrimExpr& e) { + auto ret = this->result; + ret.MutateByApply([&value_map] (const PrimExpr& e) { return Substitute(e, value_map); }); + return ret; } TVM_REGISTER_GLOBAL("tir.CommReducer") diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index f8e82ea787be..cc61e7e6abae 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -22,8 +22,8 @@ */ #include #include -#include -#include "../pass/ir_util.h" +#include + namespace tvm { namespace tir { diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc index 3fa8687e4e7f..1e6f6c6ed3a0 100644 --- a/src/tir/pass/hoist_if_then_else.cc +++ b/src/tir/pass/hoist_if_then_else.cc @@ -20,6 +20,7 @@ /*! * \file hoist_if_then_else.cc */ +#include #include #include #include @@ -415,5 +416,9 @@ Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); } + +TVM_REGISTER_GLOBAL("testing.HoistIfThenElse") +.set_body_typed(HoistIfThenElse); + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/ir_util.cc b/src/tir/pass/ir_util.cc deleted file mode 100644 index 7223c5b1c9e6..000000000000 --- a/src/tir/pass/ir_util.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ir_util.cc - * \brief Helper functions to construct and compose IR nodes. - */ -#include "ir_util.h" - -namespace tvm { -namespace tir { - -Stmt MergeNest(const std::vector& nest, Stmt body) { - // use reverse iteration - for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { - Stmt s = *ri; - if (const auto* for_ = s.as()) { - auto n = make_object(*for_); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (const auto* let = s.as()) { - auto n = make_object(*let); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (const auto* attr = s.as()) { - auto n = make_object(*attr); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (const auto* ite = s.as()) { - auto n = make_object(*ite); - CHECK(is_no_op(n->then_case)); - CHECK(!n->else_case.defined()); - n->then_case = body; - body = Stmt(n); - } else if (const auto* seq = s.as()) { - auto n = make_object(*seq); - CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); - n->seq.Set(n->size() - 1, body); - body = Stmt(n); - } else if (const auto* assert_ = s.as()) { - auto n = make_object(*assert_); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (const auto* alloc = s.as()) { - auto n = make_object(*alloc); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else { - LOG(FATAL) << "not supported nest type"; - } - } - return body; -} - -Stmt MergeNest(const std::vector >& nest, Stmt body) { - for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { - body = MergeNest(*ri, body); - } - return body; -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/pass/arg_binder.cc b/src/tir/transforms/arg_binder.cc similarity index 99% rename from src/tir/pass/arg_binder.cc rename to src/tir/transforms/arg_binder.cc index 2f030470acc1..a68e4ee0db84 100644 --- a/src/tir/pass/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -22,7 +22,6 @@ * \brief Helper utility to match and bind arguments. */ #include -#include #include #include "ir_util.h" #include "arg_binder.h" diff --git a/src/tir/pass/arg_binder.h b/src/tir/transforms/arg_binder.h similarity index 97% rename from src/tir/pass/arg_binder.h rename to src/tir/transforms/arg_binder.h index 0ff51e8c98f1..1769950b8979 100644 --- a/src/tir/pass/arg_binder.h +++ b/src/tir/transforms/arg_binder.h @@ -21,8 +21,8 @@ * \file arg_binder.h * \brief Helper utility to match and bind arguments. */ -#ifndef TVM_TIR_PASS_ARG_BINDER_H_ -#define TVM_TIR_PASS_ARG_BINDER_H_ +#ifndef TVM_TIR_TRANSFORMS_ARG_BINDER_H_ +#define TVM_TIR_TRANSFORMS_ARG_BINDER_H_ #include #include @@ -160,4 +160,4 @@ class ArgBinder { }; } // namespace tir } // namespace tvm -#endif // TVM_TIR_PASS_ARG_BINDER_H_ +#endif // TVM_TIR_TRANSFORMS_ARG_BINDER_H_ diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index f770bc76941e..4b1c0094b194 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -215,10 +215,6 @@ Stmt InstrumentBoundCheckers(Stmt stmt) { return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt)); } - -TVM_REGISTER_GLOBAL("ir_pass.InstrumentBoundCheckers") -.set_body_typed(InstrumentBoundCheckers); - namespace transform { Pass InstrumentBoundCheckers() { diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index f8e14a2a8fb3..c17d66562ef1 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -30,7 +30,6 @@ #include #include -#include #include diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index fc20285a1a22..174564f26245 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -26,8 +26,8 @@ #include #include #include -#include "../pass/ir_util.h" -#include "../pass/storage_access.h" +#include "ir_util.h" +#include "storage_access.h" namespace tvm { namespace tir { @@ -678,9 +678,6 @@ Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); } -TVM_REGISTER_GLOBAL("ir_pass.CoProcSync") -.set_body_typed(CoProcSync); - namespace transform { Pass CoProcSync() { diff --git a/src/tir/pass/ffi_api.cc b/src/tir/transforms/decorate_device_scope.cc similarity index 54% rename from src/tir/pass/ffi_api.cc rename to src/tir/transforms/decorate_device_scope.cc index 95da62a502b0..7ff2e3f7d17e 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/transforms/decorate_device_scope.cc @@ -18,39 +18,38 @@ */ /*! - * Exposure of pass functions. - * \file ffi_api.cc + * \file decorate_device_scope.cc */ -#include -#include -#include -#include -#include -#include #include +#include +#include +#include namespace tvm { namespace tir { - - -TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var()); - }); - - -// make from two arguments -#define REGISTER_PASS(PassName) \ - TVM_REGISTER_GLOBAL("ir_pass."#PassName) \ - .set_body_typed(PassName); \ - - -REGISTER_PASS(ConvertSSA); -REGISTER_PASS(VerifySSA); -REGISTER_PASS(VerifyGPUCode); -REGISTER_PASS(DecorateDeviceScope); -REGISTER_PASS(VerifyCompactBuffer); -REGISTER_PASS(HoistIfThenElse); +Stmt DecorateDeviceScope(Stmt&& stmt) { + Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), + tir::attr::device_scope, + 0, + stmt); + return body; +} + +namespace transform { + +Pass DecorateDeviceScope() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = DecorateDeviceScope(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.DecorateDeviceScope", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope") +.set_body_typed(DecorateDeviceScope); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index d409ffc4a15d..86bbefc22830 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -200,8 +200,6 @@ Stmt InjectCopyIntrin(Stmt stmt, return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt)); } -TVM_REGISTER_GLOBAL("ir_pass.InjectCopyIntrin") -.set_body_typed(InjectCopyIntrin); namespace transform { diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index e9422fa2ff3e..4e5d08c69636 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -22,11 +22,10 @@ * \file inject_double_buffer.cc */ #include -#include #include #include #include -#include "../pass/ir_util.h" +#include "ir_util.h" #include "../../arith/compute_expr.h" namespace tvm { @@ -276,9 +275,6 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) { return DoubleBufferInjector(split_loop).Inject(stmt); } -TVM_REGISTER_GLOBAL("ir_pass.InjectDoubleBuffer") -.set_body_typed(InjectDoubleBuffer); - namespace transform { diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 24747a45600c..01fb6fe0bda8 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -24,8 +24,8 @@ #include #include #include -#include #include +#include "ir_util.h" #include "../../arith/compute_expr.h" namespace tvm { @@ -502,9 +502,6 @@ Stmt InjectVirtualThread(Stmt stmt) { return ConvertSSA(std::move(stmt)); } -TVM_REGISTER_GLOBAL("ir_pass.InjectVirtualThread") -.set_body_typed(InjectVirtualThread); - namespace transform { Pass InjectVirtualThread() { diff --git a/src/tir/pass/ssa.cc b/src/tir/transforms/ir_util.cc similarity index 72% rename from src/tir/pass/ssa.cc rename to src/tir/transforms/ir_util.cc index daef32c01bdb..9ff3fca77277 100644 --- a/src/tir/pass/ssa.cc +++ b/src/tir/transforms/ir_util.cc @@ -18,60 +18,71 @@ */ /*! - * SSA related checks and pass. - * - * SSA requires each varaible to be only defined once. - * \file ssa.cc + * \file ir_util.cc + * \brief Helper functions to construct and compose IR nodes. */ -#include #include -#include +#include #include #include -#include +#include "ir_util.h" namespace tvm { namespace tir { -namespace { -class IRVerifySSA final : public StmtExprVisitor { - public: - bool is_ssa{true}; - - void VisitExpr(const PrimExpr& n) final { - if (!is_ssa) return; - StmtExprVisitor::VisitExpr(n); - } - void VisitStmt(const Stmt& n) final { - if (!is_ssa) return; - StmtExprVisitor::VisitStmt(n); - } - void VisitExpr_(const LetNode* op) final { - MarkDef(op->var.get()); - StmtExprVisitor::VisitExpr_(op); - } - void VisitStmt_(const LetStmtNode* op) final { - MarkDef(op->var.get()); - StmtExprVisitor::VisitStmt_(op); - } - void VisitStmt_(const ForNode* op) final { - MarkDef(op->loop_var.get()); - StmtExprVisitor::VisitStmt_(op); - } - void VisitStmt_(const AllocateNode* op) final { - MarkDef(op->buffer_var.get()); - StmtExprVisitor::VisitStmt_(op); - } - private: - void MarkDef(const VarNode* v) { - if (defined_.count(v) != 0) { - is_ssa = false; return; +Stmt MergeNest(const std::vector& nest, Stmt body) { + // use reverse iteration + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + Stmt s = *ri; + if (const auto* for_ = s.as()) { + auto n = make_object(*for_); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (const auto* let = s.as()) { + auto n = make_object(*let); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (const auto* attr = s.as()) { + auto n = make_object(*attr); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (const auto* ite = s.as()) { + auto n = make_object(*ite); + CHECK(is_no_op(n->then_case)); + CHECK(!n->else_case.defined()); + n->then_case = body; + body = Stmt(n); + } else if (const auto* seq = s.as()) { + auto n = make_object(*seq); + CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); + n->seq.Set(n->size() - 1, body); + body = Stmt(n); + } else if (const auto* assert_ = s.as()) { + auto n = make_object(*assert_); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (const auto* alloc = s.as()) { + auto n = make_object(*alloc); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); } else { - defined_[v] = 1; + LOG(FATAL) << "not supported nest type"; } } - std::unordered_map defined_; -}; + return body; +} + +Stmt MergeNest(const std::vector>& nest, Stmt body) { + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + body = MergeNest(*ri, body); + } + return body; +} class IRConvertSSA final : public StmtExprMutator { @@ -195,14 +206,6 @@ class IRConvertSSA final : public StmtExprMutator { std::unordered_set defined_; }; -} // namespace - -bool VerifySSA(const Stmt& ir) { - IRVerifySSA visitor; - visitor(ir); - return visitor.is_ssa; -} - Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } diff --git a/src/tir/pass/ir_util.h b/src/tir/transforms/ir_util.h similarity index 95% rename from src/tir/pass/ir_util.h rename to src/tir/transforms/ir_util.h index a167433dd112..18f79773d5f4 100644 --- a/src/tir/pass/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -21,8 +21,8 @@ * \file ir_util.h * \brief Helper functions to construct and compose IR nodes. */ -#ifndef TVM_TIR_PASS_IR_UTIL_H_ -#define TVM_TIR_PASS_IR_UTIL_H_ +#ifndef TVM_TIR_TRANSFORMS_IR_UTIL_H_ +#define TVM_TIR_TRANSFORMS_IR_UTIL_H_ #include #include @@ -174,6 +174,14 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { return align; } + +/*! + * \brief Convert a IR node to be SSA form. + * \param stmt The source statement to be converted. + * \return The converted form. + */ +Stmt ConvertSSA(Stmt stmt); + } // namespace tir } // namespace tvm -#endif // TVM_TIR_PASS_IR_UTIL_H_ +#endif // TVM_TIR_TRANSFORMS_IR_UTIL_H_ diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc index a1d922394df1..86b8cde2524c 100644 --- a/src/tir/transforms/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -26,7 +26,7 @@ #include #include #include -#include "../pass/ir_util.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -192,8 +192,6 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { return AttrScopeLifter(attr_key).Lift(std::move(stmt)); } -TVM_REGISTER_GLOBAL("ir_pass.LiftAttrScope") -.set_body_typed(LiftAttrScope); namespace transform { diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index dbed5f2abd86..dbceb37407f5 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -22,12 +22,13 @@ */ #include #include -#include #include #include #include +#include #include #include +#include "ir_util.h" #include "../../arith/interval_set.h" #include "../../runtime/thread_storage_scope.h" @@ -613,9 +614,6 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop) { } -TVM_REGISTER_GLOBAL("ir_pass.LoopPartition") -.set_body_typed(LoopPartition); - namespace transform { Pass LoopPartition(bool split_const_loop) { diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index a77d529e7764..dac426d8c273 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -27,7 +27,7 @@ #include #include #include -#include "../pass/ir_util.h" +#include "ir_util.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { @@ -143,9 +143,6 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower()(std::move(stmt)); } -TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccessInfo") -.set_body_typed(LowerStorageAccessInfo); - namespace transform { Pass LowerDeviceStorageAccessInfo() { diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 6ae638f33474..a909d4c6b83c 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -22,7 +22,6 @@ * \file lower_intrin.cc */ #include -#include #include #include diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 467e220d11ea..9cb817d04b6d 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -30,7 +30,7 @@ #include -#include "../pass/ir_util.h" +#include "ir_util.h" #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 76cfc434966d..ee6c44d21313 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -24,12 +24,11 @@ #include #include #include -#include #include #include -#include "../pass/ir_util.h" +#include "ir_util.h" namespace tvm { namespace tir { diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 96a901f2131f..516b96cd9c15 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -30,10 +30,10 @@ #include #include +#include #include #include #include -#include #include @@ -260,7 +260,7 @@ class WarpAccessRewriter : protected StmtExprMutator { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id - CHECK(!ExprUseVar(local_index, {warp_index_.get()})) + CHECK(!ExprUseVar(local_index, warp_index_)) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; PrimExpr load_value = LoadNode::make( diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 7980a9d7238f..4e5ca2ddf40d 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,7 +20,6 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ -#include #include #include #include @@ -35,8 +34,8 @@ #include #include -#include "../pass/ir_util.h" -#include "../pass/arg_binder.h" +#include "ir_util.h" +#include "arg_binder.h" namespace tvm { namespace tir { diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 4aeaafda48ba..4cf5ccdd081c 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -22,7 +22,6 @@ * \brief narrow the datatype of indexing vars */ -#include #include #include #include @@ -395,10 +394,6 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); } -TVM_REGISTER_GLOBAL("ir_pass.NarrowDataType") -.set_body_typed(NarrowDataType); - - namespace transform { Pass NarrowDataType(int target_bits) { diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 44c974fc0fb0..ceaf27b816ff 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include @@ -151,9 +151,6 @@ Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); } -TVM_REGISTER_GLOBAL("ir_pass.RemoveNoOp") -.set_body_typed(RemoveNoOp); - namespace transform { Pass RemoveNoOp() { diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 386b4cc66ed6..6052cbf32b5d 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -133,9 +133,6 @@ Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter()(std::move(stmt)); } -TVM_REGISTER_GLOBAL("ir_pass.RewriteUnsafeSelect") -.set_body_typed(RewriteUnsafeSelect); - namespace transform { Pass RewriteUnsafeSelect() { diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 1e4fd73f6d0c..752939e625ae 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include #include diff --git a/src/tir/transforms/skip_assert.cc b/src/tir/transforms/skip_assert.cc index 2857639f2e78..4511838efe57 100644 --- a/src/tir/transforms/skip_assert.cc +++ b/src/tir/transforms/skip_assert.cc @@ -18,7 +18,6 @@ */ #include -#include #include #include #include diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 927536b5938e..44f032ffa444 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -22,9 +22,10 @@ * \brief Split device function from host. */ #include +#include #include #include -#include +#include #include #include #include diff --git a/src/tir/pass/storage_access.cc b/src/tir/transforms/storage_access.cc similarity index 99% rename from src/tir/pass/storage_access.cc rename to src/tir/transforms/storage_access.cc index f6bba486c785..1f28e138a1ef 100644 --- a/src/tir/pass/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -20,7 +20,6 @@ /*! * \file storage_access.cc */ -#include #include #include #include diff --git a/src/tir/pass/storage_access.h b/src/tir/transforms/storage_access.h similarity index 96% rename from src/tir/pass/storage_access.h rename to src/tir/transforms/storage_access.h index d3614b8fff4e..12e76bd08732 100644 --- a/src/tir/pass/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -21,12 +21,12 @@ * \file storage_access.h * \brief Common data structure for storage access analysis. */ -#ifndef TVM_TIR_PASS_STORAGE_ACCESS_H_ -#define TVM_TIR_PASS_STORAGE_ACCESS_H_ +#ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ +#define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ -#include #include -#include +#include +#include #include #include #include @@ -150,4 +150,4 @@ class StorageAccessVisitor : public StmtExprVisitor { } // namespace tir } // namespace tvm -#endif // TVM_TIR_PASS_STORAGE_ACCESS_H_ +#endif // TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index e5b2ad89ace9..96686950b497 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -35,8 +35,8 @@ #include #include #include -#include "../pass/ir_util.h" -#include "../pass/arg_binder.h" +#include "ir_util.h" +#include "arg_binder.h" #include "../../arith/compute_expr.h" #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index f960306f2ee8..ca2b5a9b45aa 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include #include #include @@ -33,7 +32,7 @@ #include #include #include -#include "../pass/ir_util.h" +#include "ir_util.h" #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" @@ -937,6 +936,7 @@ class StoragePlanRewriter : public StmtExprMutator { arith::Analyzer analyzer_; }; + // Turn alloc into vector alloc // if all its access is the same vector type. class VectorAllocRewriter : public StmtExprMutator { @@ -995,6 +995,11 @@ class VectorAllocRewriter : public StmtExprMutator { arith::Analyzer analyzer_; }; +Stmt StorageRewrite(Stmt stmt) { + stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); + return VectorAllocRewriter()(std::move(stmt)); +} + PrimFunc PointerValueTypeRewrite(PrimFunc f) { auto* n = f.CopyOnWrite(); @@ -1037,13 +1042,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { return f; } -Stmt StorageRewrite(Stmt stmt) { - stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); - return VectorAllocRewriter()(std::move(stmt)); -} - -TVM_REGISTER_GLOBAL("ir_pass.StorageRewrite") -.set_body_typed(StorageRewrite); namespace transform { @@ -1060,6 +1058,17 @@ Pass StorageRewrite() { TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite") .set_body_typed(StorageRewrite); + +Pass PointerValueTypeRewrite() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + return PointerValueTypeRewrite(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.PointerValueTypeRewrite", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite") +.set_body_typed(PointerValueTypeRewrite); + } // namespace transform } // namespace tir diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 1ece078e6c3c..9924dd2bb084 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -22,7 +22,6 @@ * \file tensorcore_fragment.cc */ #include -#include #include #include #include @@ -30,8 +29,8 @@ #include #include -#include "../pass/storage_access.h" -#include "../pass/ir_util.h" +#include "storage_access.h" +#include "ir_util.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index f464af655a15..a32fd647dbb6 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -21,7 +21,6 @@ * \file thread_storage_sync.cc */ #include -#include #include #include #include @@ -31,8 +30,8 @@ #include #include -#include "../pass/ir_util.h" -#include "../pass/storage_access.h" +#include "ir_util.h" +#include "storage_access.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 5eb244d0677a..4fc69a3d892a 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -27,11 +27,11 @@ #include #include #include -#include #include #include #include #include +#include "ir_util.h" #include "../../arith/compute_expr.h" namespace tvm { @@ -204,9 +204,6 @@ Stmt UnrollLoop(Stmt stmt, } } -TVM_REGISTER_GLOBAL("ir_pass.UnrollLoop") -.set_body_typed(UnrollLoop); - namespace transform { Pass UnrollLoop(int auto_max_step, diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 22995733b31e..e155c709c746 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include #include #include @@ -553,12 +552,6 @@ Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); } -TVM_REGISTER_GLOBAL("ir_pass.VectorizeLoop") -.set_body_typed(VectorizeLoop); - -TVM_REGISTER_GLOBAL("ir_pass.SkipVectorize") -.set_body_typed(SkipVectorize); - namespace transform { // TODO(tvm-team): Make it as a target property. diff --git a/tests/cpp/ir_ssa_test.cc b/tests/cpp/ir_ssa_test.cc deleted file mode 100644 index 56f178dbcf4e..000000000000 --- a/tests/cpp/ir_ssa_test.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include - - -TEST(IRSSA, Convert) { - using namespace tvm; - using namespace tvm::tir; - Var x("x"), y; - PrimExpr let = LetNode::make(x, 1, x + 1); - - auto z = EvaluateNode::make(let + let); - CHECK(!tir::VerifySSA(z)); - auto z_ssa = tir::ConvertSSA(z); - CHECK(tir::VerifySSA(z_ssa)); -} - -TEST(IRSSA, Basic) { - using namespace tvm::tir; - using namespace tvm; - Var x("x"), y; - auto z = EvaluateNode::make(x + y); - CHECK(tir::VerifySSA(z)); -} - -int main(int argc, char ** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/simple_passes_test.cc b/tests/cpp/simple_passes_test.cc index a3c6b07ddc8a..be4c74674809 100644 --- a/tests/cpp/simple_passes_test.cc +++ b/tests/cpp/simple_passes_test.cc @@ -19,8 +19,8 @@ #include #include -#include #include +#include TEST(SimplePasses, HasSideEffect) { using namespace tvm; diff --git a/tests/python/unittest/test_tir_pass_rewrite_for_tensor_core.py b/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py similarity index 100% rename from tests/python/unittest/test_tir_pass_rewrite_for_tensor_core.py rename to tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py diff --git a/tests/python/unittest/test_tir_pass_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py similarity index 99% rename from tests/python/unittest/test_tir_pass_verify_gpu_code.py rename to tests/python/unittest/test_tir_analysis_verify_gpu_code.py index 091a3749dc74..ea17d893dca7 100644 --- a/tests/python/unittest/test_tir_pass_verify_gpu_code.py +++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py @@ -20,10 +20,11 @@ def get_verify_pass(valid, **kwargs): def _fverify(f, *_): - valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs) + valid[0] = tvm.tir.analysis.verify_gpu_code(f, kwargs) return f return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0) + def test_shared_memory(): def check_shared_memory(dtype): N = 1024 diff --git a/tests/python/unittest/test_tir_analysis_verify_memory.py b/tests/python/unittest/test_tir_analysis_verify_memory.py index b0de91b435ea..386fceb150e3 100644 --- a/tests/python/unittest/test_tir_analysis_verify_memory.py +++ b/tests/python/unittest/test_tir_analysis_verify_memory.py @@ -43,7 +43,7 @@ def test_verify_memory_all_bind(): for dev_type in gpu_devices + other_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - tvm.tir.analysis.verify_memory(binded_mod) + tvm.tir.transform.VerifyMemory()(binded_mod) @@ -63,13 +63,13 @@ def test_verify_memory_not_bind(): for dev_type in gpu_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - with pytest.raises(ValueError): - tvm.tir.analysis.verify_memory(binded_mod) + with pytest.raises(RuntimeError): + tvm.tir.transform.VerifyMemory()(binded_mod) for dev_type in other_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - tvm.tir.analysis.verify_memory(binded_mod) + tvm.tir.transform.VerifyMemory()(binded_mod) # Computations are partially bound. @@ -93,13 +93,13 @@ def test_verify_memory_partially_bind(): for dev_type in gpu_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - with pytest.raises(ValueError): - tvm.tir.analysis.verify_memory(binded_mod) + with pytest.raises(RuntimeError): + tvm.tir.transform.VerifyMemory()(binded_mod) for dev_type in other_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - tvm.tir.analysis.verify_memory(binded_mod) + tvm.tir.transform.VerifyMemory()(binded_mod) diff --git a/tests/python/unittest/test_tir_pass_basic.py b/tests/python/unittest/test_tir_analysis_verify_ssa.py similarity index 64% rename from tests/python/unittest/test_tir_pass_basic.py rename to tests/python/unittest/test_tir_analysis_verify_ssa.py index 23982873075f..8a15c3628074 100644 --- a/tests/python/unittest/test_tir_pass_basic.py +++ b/tests/python/unittest/test_tir_analysis_verify_ssa.py @@ -17,31 +17,16 @@ import tvm from tvm import te - - def test_verify_ssa(): x = te.var('x') y = te.var() z = tvm.tir.Evaluate(x + y) - assert(tvm.tir.ir_pass.VerifySSA(z)) - - -def test_convert_ssa(): - x = te.var('x') - y = te.var() - let1 = tvm.tir.Let(x, 1, x + 1) - let2 = tvm.tir.Let(x, 1, x + y) - z = tvm.tir.Evaluate(let1 + let2) - assert(not tvm.tir.ir_pass.VerifySSA(z)) - z_ssa = tvm.tir.ir_pass.ConvertSSA(z) - assert(tvm.tir.ir_pass.VerifySSA(z_ssa)) + assert(tvm.tir.analysis.verify_ssa( + tvm.tir.PrimFunc([x, y],z))) - -def test_expr_use_var(): - x = te.var('x') - assert(tvm.tir.ir_pass.ExprUseVar(x+1, x)) - assert(not tvm.tir.ir_pass.ExprUseVar(1+10, x)) + assert(not tvm.tir.analysis.verify_ssa( + tvm.tir.PrimFunc([x, y], tvm.tir.LetStmt(x, 1, z)))) if __name__ == "__main__": - test_expr_use_var() + test_verify_ssa() diff --git a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py b/tests/python/unittest/test_tir_pass_attrs_hash_equal.py deleted file mode 100644 index 9a115be74559..000000000000 --- a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te - -def test_attrs_equal(): - x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) - y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) - z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1)) - assert tvm.ir.structural_equal(x, y) - assert not tvm.ir.structural_equal(x, z) - - dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) - assert not tvm.ir.structural_equal(dattr, x) - dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) - assert tvm.ir.structural_equal(dattr, dattr2) - - assert tvm.ir.structural_equal({"x": x}, {"x": y}) - # array related checks - assert tvm.ir.structural_equal({"x": [x, x]}, {"x": [y, x]}) - assert not tvm.ir.structural_equal({"x": [x, 1]}, {"x": [y, 2]}) - - n = te.var("n") - assert tvm.ir.structural_equal({"x": n+1}, {"x": n+1}) - - - - - -def test_attrs_hash(): - fhash = tvm.ir.structural_hash - x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) - y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) - assert fhash({"x": x}) == fhash({"x": y}) - assert fhash({"x": x}) != fhash({"x": [y, 1]}) - assert fhash({"x": [x, 1]}) == fhash({"x": [y, 1]}) - assert fhash({"x": [x, 2]}) == fhash({"x": [y, 2]}) - - -if __name__ == "__main__": - test_attrs_equal() - test_attrs_hash() diff --git a/tests/python/unittest/test_tir_pass_decorate_device_scope.py b/tests/python/unittest/test_tir_pass_decorate_device_scope.py deleted file mode 100644 index 9c58431158b9..000000000000 --- a/tests/python/unittest/test_tir_pass_decorate_device_scope.py +++ /dev/null @@ -1,42 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te - -def test_decorate_device(): - m = te.size_var('m') - l = te.size_var('l') - A = te.placeholder((m, l), name='A') - - A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1') - A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') - - s = te.create_schedule(A2.op) - xo, xi = s[A2].split(A2.op.axis[0], factor=8) - s[A1].compute_at(s[A2], xo) - s[A1].set_scope("shared") - - bounds = tvm.te.schedule.InferBound(s) - stmt1 = tvm.te.schedule.ScheduleOps(s, bounds) - stmt2 = tvm.tir.ir_pass.DecorateDeviceScope(stmt1) - assert isinstance(stmt2, tvm.tir.AttrStmt) - assert stmt2.attr_key == "device_scope" - assert stmt1 == stmt2.body - -if __name__ == "__main__": - test_decorate_device() - diff --git a/tests/python/unittest/test_tir_pass_hoist_if.py b/tests/python/unittest/test_tir_pass_hoist_if.py index e7e3a250a770..346239d302cf 100644 --- a/tests/python/unittest/test_tir_pass_hoist_if.py +++ b/tests/python/unittest/test_tir_pass_hoist_if.py @@ -67,7 +67,7 @@ def test_basic(): ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) + new_stmt = tvm.testing.HoistIfThenElse(stmt) expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')), ('For', 'i'): (('IfThenElse', ('i',)),)} @@ -86,7 +86,7 @@ def test_no_else(): ib.emit(tvm.tir.Evaluate(m)) stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) + new_stmt = tvm.testing.HoistIfThenElse(stmt) expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),)} @@ -113,7 +113,7 @@ def test_attr_stmt(): data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0 stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) + new_stmt = tvm.testing.HoistIfThenElse(stmt) expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')), ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),), ('AttrStmt', 'thread_extent', 64): (('For', 'i'),), @@ -137,7 +137,7 @@ def test_nested_for(): data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) + new_stmt = tvm.testing.HoistIfThenElse(stmt) expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),), ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),)} @@ -170,7 +170,7 @@ def test_if_block(): data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6 stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) + new_stmt = tvm.testing.HoistIfThenElse(stmt) expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None), ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),), ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),), diff --git a/tests/python/unittest/test_tir_pass_ir_transform.py b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py similarity index 100% rename from tests/python/unittest/test_tir_pass_ir_transform.py rename to tests/python/unittest/test_tir_stmt_functor_ir_transform.py diff --git a/python/tvm/tir/ir_pass.py b/tests/python/unittest/test_tir_transform_decorate_device_scope.py similarity index 62% rename from python/tvm/tir/ir_pass.py rename to tests/python/unittest/test_tir_transform_decorate_device_scope.py index 239b1fb98dd0..cf9ea9e00fe1 100644 --- a/python/tvm/tir/ir_pass.py +++ b/tests/python/unittest/test_tir_transform_decorate_device_scope.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Namespace of IR pass functions. +import tvm +from tvm import te -This namespace is used for developers. While you do not see any declarations. -The functions are automatically exported from C++ side via PackedFunc. +def test_decorate_device(): + x = te.var("x") + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x))) -Each api is a PackedFunc that can be called in a positional argument manner. -You can read "include/tvm/tir/ir_pass.h" for the function signature and -"src/api/api_pass.cc" for the PackedFunc's body of these functions. -""" -import tvm._ffi + stmt = tvm.tir.transform.DecorateDeviceScope()(mod)["main"].body + assert stmt.attr_key == "device_scope" -tvm._ffi._init_api("tvm.ir_pass", __name__) +if __name__ == "__main__": + test_decorate_device() diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index 7854835a21d6..26e9438fddc5 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -31,12 +31,16 @@ def test_unroll_loop(): Aptr[j + 1] = Aptr[i] + 1 stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt)) + assert isinstance(stmt, tvm.tir.For) - ret = tvm.tir.ir_pass.UnrollLoop(stmt, 16, 8, 0, True) + + ret = tvm.tir.transform.UnrollLoop(16, 8, 0, True)(mod)["main"].body + assert not isinstance(ret, tvm.tir.For) - ret = tvm.tir.ir_pass.UnrollLoop(stmt, 15, 8, 0, True) + ret = tvm.tir.transform.UnrollLoop(15, 8, 0, True)(mod)["main"].body assert isinstance(ret, tvm.tir.For) - ret = tvm.tir.ir_pass.UnrollLoop(stmt, 16, 8, 0, False) + ret = tvm.tir.transform.UnrollLoop(16, 8, 0, False)(mod)["main"].body assert isinstance(ret, tvm.tir.For) assert ret.for_type == tvm.tir.For.Unrolled @@ -46,11 +50,9 @@ def test_unroll_loop(): wrapped = ib.get() wrapped = tvm.tir.SeqStmt([wrapped, stmt]) assert isinstance(ret, tvm.tir.For) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped)) ret = tvm.tir.transform.UnrollLoop(0, 8, 0, False)(mod)["main"].body - # ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False) assert isinstance(ret[0], tvm.tir.For) assert ret[0].for_type == tvm.tir.For.Unrolled assert isinstance(ret[1], tvm.tir.For) @@ -72,8 +74,6 @@ def test_unroll_fake_loop(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt)) ret = tvm.tir.transform.UnrollLoop(8, 0, 1, False)(mod)["main"].body - - # ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True) assert isinstance(ret[0], tvm.tir.Store) def test_unroll_single_count_loops(): diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 49eb088feb7a..41f4b453188b 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -25,7 +25,6 @@ #define TOPI_ELEMWISE_H_ #include -#include #include #include #include