diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index e61c91a6a30bf..4c54ae49fee89 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -390,7 +390,7 @@ Stmt HoistIfThenElse(Stmt stmt); * \param stmt The stmt to do datatype rewrite * \return Transformed stmt. */ -Stmt DataTypeRewrite(Stmt stmt); +Stmt NarrowDataType(Stmt stmt); /*! * \brief Make an user callable API LoweredFunc. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 220de8eb12826..7ef5565fbd4a4 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -159,7 +159,7 @@ def lower(sch, # Phase 1 stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) - stmt = ir_pass.DataTypeRewrite(stmt) + stmt = ir_pass.NarrowDataType(stmt) stmt = ir_pass.CanonicalSimplify(stmt) for f in lower_phase1: stmt = f(stmt) diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 5e921ae0155ba..39524c4a2eb57 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -177,6 +177,6 @@ REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(InferFragment) -REGISTER_PASS(DataTypeRewrite); +REGISTER_PASS(NarrowDataType); } // namespace tir } // namespace tvm diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/narrow_datatype.cc similarity index 88% rename from src/tir/pass/rewrite_datatype.cc rename to src/tir/pass/narrow_datatype.cc index b1246edbfcd3c..bd7a8d40c1280 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/narrow_datatype.cc @@ -18,7 +18,7 @@ */ /*! - * \file rewrite_datatype.cc + * \file narrow_datatype.cc * \brief narrow the datatype of indexing vars */ @@ -30,6 +30,28 @@ namespace tvm { namespace tir { +// This pass narrows indexing expressions (like StoreNode::Index) +// that trivially fit into i32 to i32. 
Considering that i32 indices
// may be more efficient on some backends (while i64 may be more
// efficient on others, like llvm), we may want this pass when i32
// indices are more efficient.
//
// For Var v, we determine its dtype by examining all the PrimExpr
// that contains v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
// If all expressions in E fit into i32, then we think v can be narrowed
// to i32.
//
// To make an indexing expression i32, we must make sure that every
// component of that expression is of dtype i32. So besides Var, we
// rewrite the following inside an indexing expression
// - Var
// - IntImm
// - Cast
//
// Algorithm:
// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
// - Use DataTypeRewriter to rewrite the components of an indexing expression.
 using arith::Analyzer; using arith::IRMutatorWithAnalyzer; using arith::ConstIntBound; @@ -166,6 +188,9 @@ class DataTypeRewriter : public StmtExprMutator { Stmt VisitStmt_(const ForNode* op) final { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); + CHECK(op != nullptr) + << "Expected type to be ForNode" + << ", but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), @@ -177,7 +202,13 @@ op->attr_key == attr::virtual_thread) { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); + CHECK(op != nullptr) + << "Expected type to be AttrStmtNode" + << ", but get " << s->GetTypeKey(); const IterVarNode* iv = op->node.as(); + CHECK(iv != nullptr) + << "Expected type to be IterVarNode" + << ", but get " << op->node->GetTypeKey(); PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { @@ -233,6 +264,9 @@ if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { PrimExpr e = 
StmtExprMutator::VisitExpr_(op); const CastNode* new_op = e.as(); + CHECK(new_op != nullptr) + << "Expected type to be CastNode" + << ", but get " << e->GetTypeKey(); return CastNode::make(visitor_.vmap[op], new_op->value); } return StmtExprMutator::VisitExpr_(op); @@ -298,6 +332,9 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); + CHECK(op != nullptr) + << "Expected type to be CallNode" + << ", but get " << e->GetTypeKey(); if (op->call_type == CallNode::PureIntrinsic) { if (op->name == intrinsic::tvm_if_then_else) { return if_then_else(op->args[0], op->args[1], op->args[2]); @@ -318,7 +355,7 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { return e; } -Stmt DataTypeRewrite(Stmt stmt) { +Stmt NarrowDataType(Stmt stmt) { return DataTypeRewriter()(stmt); } diff --git a/tests/python/unittest/test_tir_pass_rewrite_datatype.py b/tests/python/unittest/test_tir_pass_narrow_datatype.py similarity index 98% rename from tests/python/unittest/test_tir_pass_rewrite_datatype.py rename to tests/python/unittest/test_tir_pass_narrow_datatype.py index 69eee8cc19d6e..297dc166aa6a0 100644 --- a/tests/python/unittest/test_tir_pass_rewrite_datatype.py +++ b/tests/python/unittest/test_tir_pass_narrow_datatype.py @@ -32,7 +32,7 @@ def lower(sch, args): bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) - stmt = tvm.tir.ir_pass.DataTypeRewrite(stmt) + stmt = tvm.tir.ir_pass.NarrowDataType(stmt) return stmt