Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-project-ci committed Mar 21, 2020
1 parent c0244d5 commit 3ba40c2
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 15 deletions.
3 changes: 2 additions & 1 deletion include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,10 @@ Stmt HoistIfThenElse(Stmt stmt);
/*!
* \brief Narrow down PrimExpr datatype in stmt
* \param stmt The stmt to do datatype rewrite
* \param target_bits the bit of target datatype
* \return Transformed stmt.
*/
Stmt NarrowDataType(Stmt stmt);
Stmt NarrowDataType(Stmt stmt, int target_bits);

/*!
* \brief Make an user callable API LoweredFunc.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def lower(sch,
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.NarrowDataType(stmt)
stmt = ir_pass.NarrowDataType(stmt, 32)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
Expand Down
34 changes: 22 additions & 12 deletions src/tir/pass/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,25 @@ using arith::Analyzer;
using arith::IRMutatorWithAnalyzer;
using arith::ConstIntBound;

class DataTypeRewriter;

class DataTypeVisitor final : public StmtExprVisitor {
public:
explicit DataTypeVisitor(int target_bits)
: bits_(target_bits), target_bits_(target_bits) {}

void VisitExpr(const PrimExpr& e) {
if (e.dtype().is_int()) {
int bits = 64;
if (e.dtype().bits() <= 32 ||
analyzer_.CanProve(e <= max_value(DataType::Int(32)) &&
e >= min_value(DataType::Int(32)))) {
bits = 32;
int bits = max_bits_;
ConstIntBound bound = analyzer_.const_int_bound(e);
int64_t ubound = Downcast<IntImm, PrimExpr>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm, PrimExpr>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
(bound->max_value <= ubound && bound->min_value >= lbound)) {
bits = target_bits_;
}
int tmp = bits_;
bits_ = bits > bits_ ? bits : bits_;
int tmp = bits > bits_ ? bits : bits_;
std::swap(bits_, tmp);
StmtExprVisitor::VisitExpr(e);
bits_ = tmp;
std::swap(bits_, tmp);
} else {
StmtExprVisitor::VisitExpr(e);
}
Expand Down Expand Up @@ -152,14 +155,20 @@ class DataTypeVisitor final : public StmtExprVisitor {
arith::Analyzer analyzer_;

private:
// the maximum possible bits, which serves as an init value
static constexpr const int max_bits_ = 64;
// the maximum possible bit of the current expression's return dtype
int bits_;
// the target bits
int target_bits_;
// the extent of vars to be rewritten
std::unordered_map<const VarNode*, DataType> vextent_;
};

class DataTypeRewriter : public StmtExprMutator {
public:
explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {}

Stmt operator()(Stmt s) {
visitor_(s);
for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) {
Expand Down Expand Up @@ -298,6 +307,7 @@ class DataTypeRewriter : public StmtExprMutator {
// a map from IterVar before rewrite to that after rewrite,
// ensures one old IterVar maps to exactly one new IterVar
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
// indicator of LoadNode::index and StoreNode::index
bool is_index_{false};
};

Expand Down Expand Up @@ -355,8 +365,8 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
return e;
}

Stmt NarrowDataType(Stmt stmt) {
return DataTypeRewriter()(stmt);
Stmt NarrowDataType(Stmt stmt, int target_bits) {
return DataTypeRewriter(target_bits)(stmt);
}

} // namespace tir
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_pass_narrow_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def lower(sch, args):
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
stmt = tvm.tir.ir_pass.NarrowDataType(stmt)
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, 32)
return stmt


Expand Down

0 comments on commit 3ba40c2

Please sign in to comment.