From c22dd35981aab4ab9b5f492d2053acb1b18b2579 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 13 Apr 2020 17:19:19 -0700 Subject: [PATCH] [TIR] Refactor MakePackedAPI to target dependent stage. Previously MakePackedAPI was in the target independent stage, but never the less requires the device_type information that will be binded at a later target dependent stage. The previous implementation was due to the limitation of LoweredFunc which can not carry buffer_map info(so they have to be lowered right away). This is no longer the case after the unified IR refactor. This PR migrates MakePackedAPI to a target dependent stage and removes the un-necessary BindDevice pass. --- include/tvm/ir/transform.h | 9 +- include/tvm/runtime/device_api.h | 9 ++ include/tvm/tir/transform.h | 9 -- python/tvm/driver/build_module.py | 10 +- python/tvm/ir/transform.py | 27 ++++- python/tvm/testing.py | 3 +- python/tvm/tir/transform/transform.py | 42 +++---- src/driver/driver_api.cc | 5 +- src/ir/transform.cc | 13 ++ src/target/stackvm/codegen_stackvm.cc | 2 + src/tir/analysis/verify_memory.cc | 6 +- src/tir/transforms/bind_device_type.cc | 113 ------------------ src/tir/transforms/make_packed_api.cc | 25 ++-- .../python/unittest/test_runtime_extension.py | 4 +- .../unittest/test_target_codegen_llvm.py | 2 +- .../test_tir_analysis_verify_memory.py | 2 +- .../test_tir_pass_inject_double_buffer.py | 5 +- .../unittest/test_tir_pass_storage_flatten.py | 2 +- ...test_tir_transform_combine_context_call.py | 5 +- .../test_tir_transform_make_packed_api.py | 6 +- .../test_tir_transform_prim_func_pass.py | 2 +- 21 files changed, 114 insertions(+), 187 deletions(-) delete mode 100644 src/tir/transforms/bind_device_type.cc diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 83619023d485..4c55204547b9 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -352,12 +352,19 @@ class Sequential : public Pass { * * \return The created module pass. */ -Pass CreateModulePass( +TVM_DLL Pass CreateModulePass( const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, const Array& required); + +/*! + * \brief A special trace pass that prints the header and IR to LOG(INFO). + * \return The pass. + */ +TVM_DLL Pass PrintIR(std::string header); + } // namespace transform } // namespace tvm diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 470a1fefd856..f2ddc84e9f98 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -193,6 +193,15 @@ class TVM_DLL DeviceAPI { * \return The corresponding device API. */ static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false); + + /*! + * \brief Whether a certian device type requires set device context + * before launching the kernel function. + * \param device_type The device type. + */ + static bool NeedSetDeviceContext(int device_type) { + return device_type != kDLCPU && device_type != kDLMicroDev; + } }; /*! \brief The device type bigger than this is RPC device */ diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 0c5b39b0f382..23c195563ac2 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -112,15 +112,6 @@ TVM_DLL Pass RemapThreadAxis(Map axis_map); */ TVM_DLL Pass LowerCustomDatatypes(); - -/*! - * \brief Bind the device type ofthe function to be - * the device_type specified in the target attribute. - * - * \return The pass. - */ -TVM_DLL Pass BindDeviceType(); - /*! * \brief Split the function into a host function and device functions. * diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index c0e990e8394d..a429d0775dae 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -200,7 +200,7 @@ def lower(sch, if cfg.restricted_func: f = f.with_attr("tir.noalias", True) mod = tvm.IRModule({name: f}) - return tvm.tir.transform.MakePackedAPI()(mod) + return mod def _build_for_device(input_mod, target, target_host): @@ -243,13 +243,13 @@ def _build_for_device(input_mod, target, target_host): tvm.tir.transform.ThreadSync("warp"), tvm.tir.transform.InferFragment(), tvm.tir.transform.LowerThreadAllreduce(), - tvm.tir.transform.BindDeviceType(), + tvm.tir.transform.MakePackedAPI(), tvm.tir.transform.SplitHostDevice()] - mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed) + mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed) # device optimizations - opt_device = tvm.ir.transform.Sequential( + opt_device = tvm.transform.Sequential( [tvm.tir.transform.Filter( lambda f: "calling_conv" in f.attrs and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH), @@ -259,7 +259,7 @@ def _build_for_device(input_mod, target, target_host): mod_dev = opt_device(mod_mixed) # host optimizations - opt_host = tvm.ir.transform.Sequential( + opt_host = tvm.transform.Sequential( [tvm.tir.transform.Filter( lambda f: "calling_conv" not in f.attrs or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH), diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index cdb92576cd63..da74fb227a2e 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -22,13 +22,13 @@ import tvm._ffi -from tvm._ffi.runtime_ctypes import TVMContext -from tvm.runtime import Object, ndarray as _nd +import tvm.runtime +from tvm.runtime import ndarray as _nd from . import _ffi_transform_api @tvm._ffi.register_object("transform.PassInfo") -class PassInfo(Object): +class PassInfo(tvm.runtime.Object): """The class contains the meta data required by a pass. It is the container of information needed by running an optimization or analysis. This class can be extended by adding new members when more meta data is @@ -52,7 +52,7 @@ def __init__(self, opt_level, name, required=None): @tvm._ffi.register_object("transform.PassContext") -class PassContext(Object): +class PassContext(tvm.runtime.Object): """The basis where a Relay optimization/analysis runs on. Each pass context contains a number of auxiliary information that is used to help an optimization pass. Such information includes the error reporter @@ -79,7 +79,7 @@ def __init__(self, trace=None): if isinstance(fallback_device, str): fallback_device = _nd.context(fallback_device).device_type - elif isinstance(fallback_device, TVMContext): + elif isinstance(fallback_device, tvm.runtime.TVMContext): fallback_device = fallback_device.device_type if not isinstance(fallback_device, int): raise TypeError("fallback_device is expected to be the type of " + @@ -113,7 +113,7 @@ def current(): @tvm._ffi.register_object("transform.Pass") -class Pass(Object): +class Pass(tvm.runtime.Object): """The base class of all passes. All methods here are just simple wrappers that are implemented in the backend. They are defined for users to conveniently interact with the base class. @@ -327,3 +327,18 @@ def create_module_pass(pass_arg): if pass_func: return create_module_pass(pass_func) return create_module_pass + + +def PrintIR(header): + """A special trace pass that prints the header and IR. + + Parameters + ---------- + header : str + The header to be displayed along with the dump. + + Returns + -------- + The pass + """ + return _ffi_transform_api.PrintIR(header) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 1edb3b85a769..064c43891d2d 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -195,13 +195,14 @@ def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias): mod : IRModule The created IRModule. """ + assert num_unpacked_args == 0 f = tvm.tir.PrimFunc(args, stmt).with_attr( "global_symbol", tvm.runtime.String(name)) f = f.with_attr("tir.is_entry_func", True) if noalias: f = f.with_attr("tir.noalias", True) mod = tvm.IRModule({name: f}) - return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod) + return mod tvm._ffi._init_api("testing", __name__) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 91321fbf7c81..9f64a93a4860 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -32,7 +32,7 @@ def Apply(ftransform): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ # pylint: disable=unused-argument @@ -51,7 +51,7 @@ def Filter(fcond): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ # pylint: disable=unused-argument @@ -67,7 +67,7 @@ def LowerCustomDatatypes(): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.LowerCustomDatatypes() @@ -84,30 +84,18 @@ def MakePackedAPI(num_unpacked_params=0): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.MakePackedAPI(num_unpacked_params) -def BindDeviceType(): - """Bind the device type of the function to be - the device_type specified in the target attribute. - - Returns - ------- - fpass : tvm.ir.transform.Pass - The result pass - """ - return _ffi_api.BindDeviceType() - - def SplitHostDevice(): """Split the function into a host function and device functions. Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.SplitHostDevice() @@ -118,7 +106,7 @@ def SkipAssert(): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.SkipAssert() @@ -134,7 +122,7 @@ def ThreadSync(storage_scope): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.ThreadSync(storage_scope) @@ -145,7 +133,7 @@ def LowerThreadAllreduce(): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.LowerThreadAllreduce() @@ -156,7 +144,7 @@ def InferFragment(): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.InferFragment() @@ -167,7 +155,7 @@ def LowerWarpMemory(): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.LowerWarpMemory() @@ -178,7 +166,7 @@ def LowerTVMBuiltin(): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.LowerTVMBuiltin() @@ -189,7 +177,7 @@ def LowerIntrin(): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.LowerIntrin() @@ -200,7 +188,7 @@ def LowerDeviceStorageAccessInfo(): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass Note @@ -215,7 +203,7 @@ def CombineContextCall(): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass """ return _ffi_api.CombineContextCall() @@ -231,7 +219,7 @@ def NarrowDataType(target_bits): Returns ------- - fpass : tvm.ir.transform.Pass + fpass : tvm.transform.Pass The result pass Note diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d7955a2ca620..f576c842b25c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -216,8 +216,7 @@ IRModule lower(te::Schedule sch, if (config->restricted_func) { f = WithAttr(std::move(f), "tir.noalias", Integer(1)); } - auto mod = IRModule(Map({{GlobalVar(name), f}})); - return tir::transform::MakePackedAPI(0)(mod); + return IRModule(Map({{GlobalVar(name), f}})); } @@ -237,7 +236,7 @@ split_dev_host_funcs(IRModule mod_mixed, mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - mixed_pass_list.push_back(tir::transform::BindDeviceType()); + mixed_pass_list.push_back(tir::transform::MakePackedAPI(0)); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); auto opt_mixed = transform::Sequential(mixed_pass_list); mod_mixed = opt_mixed(std::move(mod_mixed)); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index ef524c3d54be..0161cb377f0d 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -473,5 +473,18 @@ TVM_REGISTER_GLOBAL("transform.EnterPassContext") TVM_REGISTER_GLOBAL("transform.ExitPassContext") .set_body_typed(PassContext::Internal::ExitScope); + +Pass PrintIR(std::string header) { + auto pass_func =[header](IRModule mod, const PassContext& ctx) { + LOG(INFO) << "PrintIR(" << header << "):\n" + << mod; + return mod; + }; + return CreateModulePass(pass_func, 0, "PrintIR", {}); +} + +TVM_REGISTER_GLOBAL("transform.PrintIR") +.set_body_typed(PrintIR); + } // namespace transform } // namespace tvm diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 661fdabd3c32..383aaf38cea3 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -58,6 +58,8 @@ StackVM::StructFieldKind MapFieldKind(int64_t kind) { } StackVM CodeGenStackVM::Compile(const PrimFunc& f) { + CHECK_EQ(f->buffer_map.size(), 0U) + << "Cannot codegen function with buffer_map, please lower them first"; for (size_t i = 0; i < f->params.size(); ++i) { Var v = f->params[i]; int vid = AllocVarID(v.get()); diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 9ff4f3d5b738..2a87b2e3e271 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -114,9 +114,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Check if the value of a Variable comes from function argument. bool IsFromFunctionArgs(const VarNode *var) const { const VarNode *V = var; - while (true) { - CHECK(V) << "Invalid Variable\n"; + for (auto kv : func_->buffer_map) { + if (V == kv.second->data.get()) return true; + } + while (true) { // Variable is from function args. Return true. if (V == func_->params[0].get()) return true; diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc deleted file mode 100644 index a6db9f9c6da8..000000000000 --- a/src/tir/transforms/bind_device_type.cc +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file bind_device_type.cc - * \brief Bind the device type according to the target field. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace tir { - -class DeviceTypeBinder: public StmtExprMutator { - public: - explicit DeviceTypeBinder(int device_type) - : device_type_(device_type) {} - - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::device_context_type) { - if (const VarNode* var = op->value.as()) { - var_ = var; - PrimExpr value = make_const(op->value.dtype(), device_type_); - Stmt body = StmtExprMutator::VisitStmt_(op); - var_ = nullptr; - std::ostringstream os; - os << "device_type need to be " << device_type_; - return AssertStmtNode::make(op->value == value, tvm::tir::StringImmNode::make(os.str()), - body); - } - } - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const IfThenElseNode* op) final { - // eager simplify if guard. - Stmt res = StmtExprMutator::VisitStmt_(op); - op = res.as(); - if (is_zero(op->condition)) { - if (op->else_case.defined()) return op->else_case; - return EvaluateNode::make(0); - } - if (is_one(op->condition)) { - return op->then_case; - } - return res; - } - - PrimExpr VisitExpr_(const NENode* op) final { - // eager check NE for device check - PrimExpr res = StmtExprMutator::VisitExpr_(op); - op = res.as(); - if (tir::ExprDeepEqual()(op->a, op->b)) { - return make_const(op->dtype, false); - } - return res; - } - - PrimExpr VisitExpr_(const VarNode* op) final { - if (op == var_) { - return make_const(op->dtype, device_type_); - } else { - return GetRef(op); - } - } - - public: - const VarNode* var_{nullptr}; - int device_type_; -}; - -namespace transform { - -Pass BindDeviceType() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "BindDeviceType: Require the target attribute"; - n->body = DeviceTypeBinder(target.value()->device_type)(std::move(n->body)); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.BindDeviceType") -.set_body_typed(BindDeviceType); - -} // namespace transform -} // namespace tir -} // namespace tvm diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index dd4bd6642676..7980a9d7238f 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -50,6 +51,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func, auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; + + auto target = func->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) + << "MakePackedAPI: Require the target attribute"; + int target_device_type = target.value()->device_type; + std::string name_hint = global_symbol.value(); auto* func_ptr = func.CopyOnWrite(); @@ -68,7 +75,8 @@ PrimFunc MakePackedAPI(PrimFunc&& func, // The arguments of the function. Array args; // The device context - Var device_type("dev_type"), device_id("dev_id"); + Var device_id("dev_id"); + Integer device_type(target_device_type); // seq_init gives sequence of initialization // seq_check gives sequence of later checks after init std::vector seq_init, seq_check; @@ -195,17 +203,18 @@ PrimFunc MakePackedAPI(PrimFunc&& func, // Set device context if (vmap.count(device_id.get())) { PrimExpr node = StringImmNode::make("default"); - CHECK(vmap.count(device_type.get())); seq_check.push_back(AttrStmtNode::make( node, attr::device_context_id, device_id, nop)); seq_check.push_back(AttrStmtNode::make( node, attr::device_context_type, device_type, nop)); - Stmt set_device = IfThenElseNode::make( - device_type != kDLCPU, EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed, - {StringImmNode::make(runtime::symbol::tvm_set_device), - device_type, device_id}, CallNode::Intrinsic))); - body = SeqStmt({set_device, body}); + + if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { + Stmt set_device = EvaluateNode::make(CallNode::make( + DataType::Int(32), intrinsic::tvm_call_packed, + {StringImmNode::make(runtime::symbol::tvm_set_device), + device_type, device_id}, CallNode::Intrinsic)); + body = SeqStmt({set_device, body}); + } } func_ptr->body = MergeNest( {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); diff --git a/tests/python/unittest/test_runtime_extension.py b/tests/python/unittest/test_runtime_extension.py index d9088b64168d..52fc8c233a12 100644 --- a/tests/python/unittest/test_runtime_extension.py +++ b/tests/python/unittest/test_runtime_extension.py @@ -39,10 +39,8 @@ def test_dltensor_compatible(): A[i + 1] = A[i] + 1 stmt = ib.get() - mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True) - mod = tvm.tir.transform.LowerTVMBuiltin()(mod) - f = tvm.target.codegen.build_module(mod, "stackvm") + f = tvm.build(mod, target="stackvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) aview = MyTensorView(a) f(aview) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 3de1d1679e70..76f96d4a3ba0 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -111,7 +111,7 @@ def test_llvm_lookup_intrin(): x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z]) ib.emit(x) body = ib.get() - func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True) + func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 0, True) fcode = tvm.build(func, None, "llvm") diff --git a/tests/python/unittest/test_tir_analysis_verify_memory.py b/tests/python/unittest/test_tir_analysis_verify_memory.py index f993c915aa9c..b3625082f6ed 100644 --- a/tests/python/unittest/test_tir_analysis_verify_memory.py +++ b/tests/python/unittest/test_tir_analysis_verify_memory.py @@ -44,7 +44,7 @@ def lower(sch, args): f = tvm.tir.PrimFunc(arg_list, stmt).with_attr( "global_symbol", tvm.runtime.String("test")) mod = tvm.IRModule({"test": f}) - return tvm.tir.transform.MakePackedAPI()(mod) + return mod # All computations are bound. diff --git a/tests/python/unittest/test_tir_pass_inject_double_buffer.py b/tests/python/unittest/test_tir_pass_inject_double_buffer.py index 95a10547463c..6b04db30f6d5 100644 --- a/tests/python/unittest/test_tir_pass_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_pass_inject_double_buffer.py @@ -40,9 +40,10 @@ def test_double_buffer(): stmt = tvm.tir.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 - mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True) + mod = tvm.IRModule({ + "db" : tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt) + }) f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] - count = [0] def count_sync(op): if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": diff --git a/tests/python/unittest/test_tir_pass_storage_flatten.py b/tests/python/unittest/test_tir_pass_storage_flatten.py index da9253f1dfca..88799c4736d9 100644 --- a/tests/python/unittest/test_tir_pass_storage_flatten.py +++ b/tests/python/unittest/test_tir_pass_storage_flatten.py @@ -92,7 +92,7 @@ def test_flatten_double_buffer(): stmt = tvm.tir.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 - mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True) + mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 0, True) f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index 6f2bc65450be..7fd2593bd365 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -36,7 +36,10 @@ def device_context(dev_id): ib.emit(tvm.tir.call_extern ("int32", "fadd", device_context(0), A)) body = ib.get() - mod = tvm.testing.MakeAPILegacy(body, "func", [dev_type, n], 2, True) + mod = tvm.IRModule({ + "func" : tvm.tir.PrimFunc([dev_type, n], body) + }) + mod = tvm.tir.transform.CombineContextCall()(mod) assert mod["func"].body.value.dtype == "handle" diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 898b08ee91d5..7222a617ced5 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -35,8 +35,10 @@ def test_makeapi(): stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64) num_unpacked_args = 2 - f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr( - "tir.noalias", True).with_attr("global_symbol", tvm.runtime.String("myadd")) + f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt) + f = f.with_attr("global_symbol", "myadd") + f = f.with_attr("target", tvm.target.create("llvm")) + mod = tvm.IRModule.from_expr(f) f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"] assert(len(f.params) == 7) diff --git a/tests/python/unittest/test_tir_transform_prim_func_pass.py b/tests/python/unittest/test_tir_transform_prim_func_pass.py index f286bf06cce2..977f50e1942f 100644 --- a/tests/python/unittest/test_tir_transform_prim_func_pass.py +++ b/tests/python/unittest/test_tir_transform_prim_func_pass.py @@ -60,7 +60,7 @@ def fapply(f): del func # copy on write mod_hash = mod.__hash__() - mod = tvm.ir.transform.Sequential( + mod = tvm.transform.Sequential( [pidentity, tvm.tir.transform.NarrowDataType(32)])(mod._move()) assert mod_hash == mod.__hash__() assert func_hash == mod["main"].__hash__()