[TIR][PASS] dtype rewrite for indexing variables (#5092)
meta-project-ci committed Apr 2, 2020
1 parent 4195b2e commit 4e5c584
Showing 16 changed files with 703 additions and 8 deletions.
9 changes: 9 additions & 0 deletions include/tvm/arith/analyzer.h
@@ -114,6 +114,15 @@ class ConstIntBoundAnalyzer {
*/
ConstIntBound operator()(const PrimExpr& expr);

/*!
* \brief Analyze the expr, memoizing intermediate results to avoid redundant computation.
* \param expr The expression of interest.
* \param bound The lookup table used to store the intermediate results.
* \return the result of the analysis.
*/
ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);

/*!
* \brief Update constant int bound information of var.
*
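
For context, a minimal sketch of the bound analysis this overload memoizes, using the Python Analyzer wrapper (an assumption — the wrapper and its const_int_bound method are not part of this diff, and the memo-table overload itself is only reachable from C++):

```python
import tvm
from tvm import te

analyzer = tvm.arith.Analyzer()
n = te.var("n", dtype="int32")
# Tell the analyzer that n lies in [0, 127].
analyzer.update(n, tvm.arith.ConstIntBound(0, 127))
# The bound propagates through arithmetic: n * 4 + 1 is in [1, 509].
bound = analyzer.const_int_bound(n * 4 + 1)
print(bound.min_value, bound.max_value)
```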
9 changes: 9 additions & 0 deletions include/tvm/tir/ir_pass.h
@@ -357,6 +357,15 @@ Stmt DecorateDeviceScope(Stmt stmt);
*/
Stmt HoistIfThenElse(Stmt stmt);

/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
* \note Run this pass after StorageFlatten.
* \param stmt The stmt on which to perform the datatype rewrite.
* \param target_bits The bit width of the target datatype.
* \return Transformed stmt.
*/
Stmt NarrowDataType(Stmt stmt, int target_bits);

/*!
* \brief Make an user callable API LoweredFunc.
*
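
A hedged sketch of invoking the statement-level pass directly, mirroring the lower() pipeline change further below (the ir_builder construction is illustrative; which indices actually get narrowed depends on what the bound analysis can prove):

```python
import tvm

ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, 1024, name="i") as i:
    A[i] = A[i] + 1.0
stmt = ib.get()
# Rewrite indexing expressions down to 32 bits where the bounds allow it;
# indices that cannot be proven to fit are left untouched.
narrowed = tvm.tir.ir_pass.NarrowDataType(stmt, 32)
```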
10 changes: 10 additions & 0 deletions include/tvm/tir/transform.h
@@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo();
*/
TVM_DLL Pass LowerWarpMemory();


/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
* \note Run this pass after StorageFlatten.
*
* \return The pass.
*/
TVM_DLL Pass NarrowDataType();

} // namespace transform
} // namespace tir
} // namespace tvm
1 change: 1 addition & 0 deletions python/tvm/driver/build_module.py
@@ -159,6 +159,7 @@ def lower(sch,
    # Phase 1
    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
    stmt = ir_pass.NarrowDataType(stmt, 32)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)
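
For illustration, a sketch of the end-to-end effect of this Phase-1 addition (the te schedule here is an assumption, not part of the diff): with an int64 extent, the axis variable starts out int64 (per the tir/expr.py change below), and the narrowing pass can rewrite it to int32 because 0 <= i < 1024 fits.

```python
import tvm
from tvm import te

n = tvm.tir.IntImm("int64", 1024)
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
# The printed IR should show 32-bit loop/index variables after Phase 1.
print(tvm.lower(s, [A, B], simple_mode=True))
```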
3 changes: 2 additions & 1 deletion python/tvm/tir/expr.py
@@ -370,7 +370,8 @@ def __init__(self, dom, var, iter_type, thread_tag=""):
                raise TypeError("dom need to be Range")

        name = var if var is not None else "iter"
        var = Var(name, dtype="int32") if not isinstance(var, Var) else var
        dtype = "int32" if dom is None else dom.extent.dtype
        var = Var(name, dtype=dtype) if not isinstance(var, Var) else var
        self.__init_handle_by_constructor__(
            _ffi_api.IterVar, dom, var, iter_type, thread_tag)

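
A small sketch of the new behavior (the Range construction is illustrative): the iteration variable's dtype now follows dom.extent.dtype instead of always defaulting to int32.

```python
import tvm

dom = tvm.ir.Range.make_by_min_extent(tvm.tir.IntImm("int64", 0),
                                       tvm.tir.IntImm("int64", 64))
k = tvm.tir.IterVar(dom, "k", tvm.tir.IterVar.CommReduce)
# The underlying Var is created with dom.extent.dtype, so this prints "int64".
print(k.var.dtype)
```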
6 changes: 4 additions & 2 deletions python/tvm/tir/ir_builder.py
@@ -76,7 +76,8 @@ def dtype(self):
    def __getitem__(self, index):
        t = DataType(self._content_type)
        if t.lanes > 1:
            index = _expr.Ramp(index * t.lanes, 1, t.lanes)
            base = index * t.lanes
            index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
        return _expr.Load(self._content_type, self._buffer_var, index)

    def __setitem__(self, index, value):
@@ -87,7 +88,8 @@ def __setitem__(self, index, value):
                    value.dtype, self._content_type))
        t = DataType(self._content_type)
        if t.lanes > 1:
            index = _expr.Ramp(index * t.lanes, 1, t.lanes)
            base = index * t.lanes
            index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
        self._builder.emit(_stmt.Store(self._buffer_var, value, index))


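
A sketch of the vectorized-access path these two methods handle (the allocation below is illustrative): for a multi-lane buffer, A[j] turns into a Ramp whose stride constant now shares the dtype of the computed base j * lanes rather than being a bare Python int.

```python
import tvm

ib = tvm.tir.ir_builder.create()
# Four float32 lanes per element, so loads/stores use 4-wide Ramp indices.
A = ib.allocate("float32x4", 16, name="A", scope="global")
with ib.for_range(0, 16, name="j") as j:
    # The index becomes Ramp(j * 4, const(1, (j * 4).dtype), 4).
    A[j] = A[j] + tvm.tir.Broadcast(tvm.tir.FloatImm("float32", 2.0), 4)
stmt = ib.get()
```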
15 changes: 15 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -66,3 +66,18 @@ def LowerWarpMemory():
        The result pass
    """
    return _ffi_api.LowerWarpMemory()


def NarrowDataType():
    """Narrow down PrimExpr datatype in stmt to target_bits.

    Returns
    -------
    fpass : tvm.ir.transform.Pass
        The result pass

    Note
    ----
    Run this pass after StorageFlatten.
    """
    return _ffi_api.NarrowDataType()
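
A hedged sketch of using the pass-object form with the unified pass infrastructure (the IRModule `mod` is assumed to exist and contain PrimFuncs; it is not built here):

```python
import tvm

# Compose the new pass with other module-level TIR passes; applying the
# sequence rewrites indexing dtypes inside each PrimFunc of `mod`.
seq = tvm.ir.transform.Sequential([
    tvm.tir.transform.NarrowDataType(),
    tvm.tir.transform.LowerWarpMemory(),
])
# mod = seq(mod)
```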
32 changes: 32 additions & 0 deletions src/arith/const_int_bound.cc
@@ -146,9 +146,30 @@ class ConstIntBoundAnalyzer::Impl :
res = Intersect(res, info.bound);
}
}
if (bound_) {
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
if (val != bound_->end()) {
CHECK(val->second->min_value == res.min_value &&
val->second->max_value == res.max_value)
<< "Detected bound for " << expr
<< "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
}
return res;
}

Entry VisitExpr_(const RampNode* op) final {
// op = {base + i * stride | 0 <= i < lanes}
// Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes)
// Note that `base + i * stride` is linear w.r.t. `i`
// Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1)
Entry a = VisitExpr(op->base);
Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride);
return Union(a, b);
}

Entry VisitExpr_(const CastNode* op) final {
Entry a = VisitExpr(op->value);
Entry b = Everything(op->dtype);
@@ -340,10 +361,13 @@ class ConstIntBoundAnalyzer::Impl :
}

private:
friend class ConstIntBoundAnalyzer;
// internal variable map
std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
// lookup table for memoization
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
// constants: the limit value means unlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
@@ -536,6 +560,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
return ConstIntBound(ret.min_value, ret.max_value);
}

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
impl_->bound_ = bound;
Entry ret = impl_->VisitExpr(expr);
impl_->bound_ = nullptr;
return ConstIntBound(ret.min_value, ret.max_value);
}

void ConstIntBoundAnalyzer::Update(const Var& var,
const ConstIntBound& info,
bool override) {
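
A worked instance of the RampNode rule added above (using the Python Analyzer wrapper, which is an assumption here): Ramp(base=2, stride=3, lanes=4) covers {2, 5, 8, 11}, and because the lanes are linear in i, the union of the bounds at i = 0 and i = lanes - 1 bounds the whole ramp.

```python
import tvm

analyzer = tvm.arith.Analyzer()
ramp = tvm.tir.Ramp(tvm.tir.IntImm("int32", 2), tvm.tir.IntImm("int32", 3), 4)
bound = analyzer.const_int_bound(ramp)
print(bound.min_value, bound.max_value)  # expected: 2 11
```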
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_cpu.cc
@@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
ConstInt32(1),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
op->loop_var,
op->body);
}
3 changes: 2 additions & 1 deletion src/target/llvm/codegen_llvm.cc
@@ -1121,7 +1121,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
CHECK(op->for_type == ForType::Serial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
ConstInt32(1), op->loop_var, op->body);
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
op->loop_var, op->body);
}


2 changes: 1 addition & 1 deletion src/tir/ir/buffer.cc
@@ -452,7 +452,7 @@ Buffer BufferNode::make(Var data,
n->buffer_type = buffer_type;
if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
for (size_t i = 0; i < n->shape.size(); ++i) {
n->strides.push_back(Var("stride"));
n->strides.push_back(Var("stride", n->shape[i].dtype()));
}
}
return Buffer(n);
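
A sketch of where the stride dtype shows up (decl_buffer's auto_broadcast mode; illustrative): with int64 shape dimensions, the implicit stride variables now come out as int64 as well.

```python
import tvm

n = tvm.tir.IntImm("int64", 32)
buf = tvm.tir.decl_buffer((n, n), "float32", name="A", buffer_type="auto_broadcast")
# Each implicit stride Var is created with the matching shape dtype.
print([s.dtype for s in buf.strides])  # expected: ["int64", "int64"]
```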
1 change: 1 addition & 0 deletions src/tir/pass/ffi_api.cc
@@ -156,5 +156,6 @@ REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
2 changes: 1 addition & 1 deletion src/tir/pass/loop_partition.cc
@@ -587,7 +587,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
return ForNode::make(for_node->loop_var, 0, extent,
return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
for_node->for_type, for_node->device_api, body);
}
}
4 changes: 3 additions & 1 deletion src/tir/pass/unroll_loop.cc
@@ -160,7 +160,9 @@ class LoopUnroller : public StmtExprMutator {
PrimExpr extent = tir::Simplify(op->extent);
const IntImmNode *v1 = extent.as<IntImmNode>();
int value = -1;
if (v1 != nullptr) {
// integers that do not fit in int32_t are treated as symbolic,
// as it's impossible to unroll such large loops
if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
value = static_cast<int>(v1->value);
}
return value;