Skip to content

Commit

Permalink
Fuse&split (#408)
Browse files Browse the repository at this point in the history
* first commit

* fix cpplint

* fix

* remove redundant blank

* address comments

* lint

* address comments

* address comments

* address comments

* change fuse

* change split

* polish

* lint

* fix rebase

* fix bug and add tests

* clang format

* address comments

* format

* address comments

* address comments

* add symbolic test

* lint

* address comment

* check stage pipeline

* fix mypy

* check stage_pipeline

* Revert "check stage_pipeline"

This reverts commit a5a7f4f

* add stage_pipeline_assert

Co-authored-by: jinhongyi <[email protected]>
  • Loading branch information
jinhongyii and jinhongyi authored Jul 14, 2021
1 parent a4775c2 commit 587f42d
Show file tree
Hide file tree
Showing 13 changed files with 1,268 additions and 16 deletions.
12 changes: 12 additions & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,18 @@ class IterSumExpr : public IterMapExpr {
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);
/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
* \param indices The indices to detect pattern for.
* \param input_iters Map from variable to iterator's range.
* \param input_pred The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
*
* \return The indices after rewrite
*/
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool require_bijective);

/*!
* \brief Apply the inverse of the affine transformation to the outputs.
Expand Down
31 changes: 31 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ class ScheduleNode : public runtime::Object {
* \return The corresponding loop sref
*/
virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
/*!
* \brief Get the block srefs corresponding to an array of BlockRVs
* \param block_rvs The BlockRVs to be looked up
* \return The corresponding block srefs
*/
virtual Array<StmtSRef> GetSRefs(const Array<BlockRV>& block_rvs) const = 0;
/*!
* \brief Get the loop srefs corresponding to an array of LoopRVs
* \param loop_rvs The LoopRVs to be looked up
* \return The corresponding loop srefs
*/
virtual Array<StmtSRef> GetSRefs(const Array<LoopRV>& loop_rvs) const = 0;
/*!
* \brief Get the block/loop sref corresponding to the specific statement
* \param stmt The statement to be looked up
Expand Down Expand Up @@ -196,6 +208,25 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
/******** Schedule: loops manipulation ********/
/*!
* \brief Fuse consecutive loops into one. It requires:
* 1) The loops can't have annotations or thread bindings.
* 2) The (i+1)-th loop must be the only child of the i-th loop.
* 3) All loops must start with 0.
* \param loop_rvs The loops to be fused
* \return The fused loop
*/
virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs) = 0;
/*!
* \brief Split a specified loop into two or more with the specific factor.It requires:
* 1) The loop can't have annotation or thread binding.
* 2) The loop must start with 0.
* \param loop_rv The loop to be split
* \param factors The tiling factors, and at most one of which is -1, which means that
* factor is inferred.
* \return The loops after splitting
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<ExprRV>& factors) = 0;
/******** Schedule: compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
Expand Down
131 changes: 129 additions & 2 deletions python/tvm/tir/schedule/schedule.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=unused-import
"""The TensorIR schedule class"""
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
Expand All @@ -43,7 +43,7 @@ class BlockRV(Object):
"""A random variable that refers to a block"""


ExprRV = PrimExpr # A random variable that evaluates to an integer
ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer

RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name

Expand Down Expand Up @@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member

########## Schedule: loops manipulation ##########
def fuse(self, *loops: List[LoopRV]) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
1) The loops can't have annotations or thread bindings.
2) The (i+1)-th loop must be the only child of the i-th loop.
3) All loops must start with 0.
Parameters
----------
*loops : List[LoopRV]
The loops to be fused
Returns
----------
fused_loop : LoopRV
The new loop after fusion
Examples
--------
Before fuse, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_fuse(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do fuse:
.. code-block:: python
sch = tir.Schedule(before_fuse, debug_mode=True)
i, j = sch.get_loops(sch.get_block("B"))
sch.fuse(i, j)
print(tvm.script.asscript(sch.mod["main"]))
After applying fuse, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_fuse(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, [128, 128])
for i0_i1_fused in tir.serial(0, 16384):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, tir.floordiv(i0_i1_fused, 128))
tir.bind(vj, tir.floormod(i0_i1_fused, 128))
tir.reads([A[vi, vj]])
tir.writes([B[vi, vj]])
B[vi, vj] = A[vi, vj] * 2.0
"""
return _ffi_api_schedule.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member

def split(
self,
loop: LoopRV,
factors: List[Optional[ExprRV]],
) -> List[LoopRV]:
"""Split a loop into a list of consecutive loops. It requires:
1) The loop can't have annotation or thread binding.
2) The loop must start with 0.
Predicates may be added to ensure the total loop numbers keeps unchanged.
In `factors`, at most one of the factors can be None or -1,
which will be automatically inferred.
Parameters
----------
loop : LoopRV
The loop to be split
factors: List[Optional[ExprRV]]
The splitting factors
Returns
----------
split_loops : List[LoopRV]
The new loops after split
Examples
--------
Before split, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_split(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do fuse:
.. code-block:: python
sch = tir.Schedule(before_split, debug_mode=True)
i, j = sch.get_loops(sch.get_block("B"))
sch.split(i, factors=[2, 64])
print(tvm.script.asscript(sch.mod["main"]))
After applying split, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_split(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, [128, 128])
for i0_outer, i0_inner, i1 in tir.grid(2, 64, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, ((i0_outer*64) + i0_inner))
tir.bind(vj, i1)
tir.reads([A[vi, vj]])
tir.writes([B[vi, vj]])
B[vi, vj] = A[vi, vj] * 2.0
"""
for i, factor in enumerate(factors):
if factor is None:
factors[i] = -1
return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member

########## Schedule: compute location ##########
def compute_inline(self, block: BlockRV) -> None:
"""Inline a block into its consumer(s). It requires:
Expand Down
17 changes: 16 additions & 1 deletion src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,6 @@ class IterMapRewriter : public ExprMutator {
*/
Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
if (!is_zero(expr->base)) return NullOpt;
if (expr->args.size() == 1) return expr->args[0];
// select the iterators in order
std::vector<bool> visited(expr->args.size(), false);
std::vector<IterSplitExpr> flattened_iters, grouped_iters;
Expand Down Expand Up @@ -1086,6 +1085,22 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter
return NormalizeIterMapToExpr(expr);
});

Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool require_bijective) {
Analyzer analyzer;
Array<IterSumExpr> rewrite =
DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer);
if (rewrite.empty()) {
return indices;
} else {
Array<PrimExpr> res;
res.reserve(rewrite.size());
IterMapToExprNormalizer converter(&analyzer);
for (const auto& expr : rewrite) res.push_back(converter.Convert(expr));
return res;
}
}

/*!
* \brief Divider to divide the bindings into two sets of bindings(outer and inner)
* such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X.
Expand Down
4 changes: 3 additions & 1 deletion src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0);

TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));

TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));
Expand Down Expand Up @@ -881,7 +883,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {

TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0);
TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x));
TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y));

Expand Down
6 changes: 6 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
* \return A list of leaf blocks
*/
Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref);
/*!
* \brief Get the direct child Schedulable Stmt (Block and For)
* \param stmt the parent stmt.
* \return the list of child stmts
*/
Array<Stmt> GetChildren(const Stmt& stmt);

} // namespace tir
} // namespace tvm
Expand Down
30 changes: 30 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,5 +298,35 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent
throw;
}

Array<Stmt> GetChildren(const Stmt& stmt) {
/*! \note Nested SeqStmt is not allowed in schedule. */
Stmt body;
if (const auto* block = stmt.as<BlockNode>()) {
body = block->body;
} else if (const auto* loop = stmt.as<ForNode>()) {
body = loop->body;
} else {
LOG(FATAL) << "The Stmt can only be a Block or a For";
}
if (const auto* seq = body.as<SeqStmtNode>()) {
Array<Stmt> ret;
for (const Stmt& child : seq->seq) {
ICHECK(!child->IsInstance<SeqStmtNode>()) << "Nested SeqStmt is not allowed in schedule.";
if (child->IsInstance<BlockRealizeNode>()) {
ret.push_back(child.as<BlockRealizeNode>()->block);
} else {
ret.push_back(child);
}
}
return ret;
} else {
if (body->IsInstance<BlockRealizeNode>()) {
return Array<Stmt>{body.as<BlockRealizeNode>()->block};
} else {
return Array<Stmt>{body};
}
}
}

} // namespace tir
} // namespace tvm
28 changes: 28 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,34 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
}

/******** Schedule: loops manipulation ********/

LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
TVM_TIR_SCHEDULE_BEGIN();
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
StmtSRef fused_sref = tir::Fuse(state_, loop_srefs);
this->state_->DebugVerify();
return CreateRV<LoopRV>(fused_sref);
TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
throw;
}

Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv, const Array<ExprRV>& factor_rvs) {
TVM_TIR_SCHEDULE_BEGIN();
// Prepare for the splitting
StmtSRef loop_sref = this->GetSRef(loop_rv);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
Array<PrimExpr> factors;
factors.reserve(factor_rvs.size());
for (const ExprRV& factor_rv : factor_rvs) {
factors.push_back(this->Get(factor_rv));
}
Array<StmtSRef> results = tir::Split(state_, loop_sref, factors);
return CreateRV<LoopRV>(results);
TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
throw;
}

/******** Schedule: compute location ********/

void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
Expand Down
Loading

0 comments on commit 587f42d

Please sign in to comment.