Skip to content

Commit

Permalink
[Arith] Simplify MatchFusePattern in InverseAffineMap (#8427)
Browse files Browse the repository at this point in the history
* [Arith] Simplify MatchFusePattern in InverseAffineMap

* fix
  • Loading branch information
vinx13 committed Jul 9, 2021
1 parent 0fa4396 commit 683c5eb
Showing 1 changed file with 14 additions and 31 deletions.
45 changes: 14 additions & 31 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1425,16 +1425,15 @@ class InverseAffineIterMapTransformer {
return;
}

// Case 2: If the sum expression has multiple components, match the fuse pattern and then split
// Case 2: If the sum expression has multiple components, check the fuse pattern and then split
// the sum expression for each components.
// For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2
// we will have i1_i2_fused[dom = (0, 64)]. During back propagation, we need to split the
// propagated value to get the corresponding components of i1 and i2, which are
// floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively.
Array<IterSplitExpr> splits = MatchFusePattern(iter_map_expr);
ICHECK(!splits.empty());

for (const IterSplitExpr& split : splits) {
CheckFusePattern(iter_map_expr);
for (size_t i = iter_map_expr->args.size(); i > 0; i--) {
const IterSplitExpr& split = iter_map_expr->args[i - 1];
backprop_.Set(split,
backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent));
}
Expand Down Expand Up @@ -1485,33 +1484,17 @@ class InverseAffineIterMapTransformer {
}
}

Array<IterSplitExpr> MatchFusePattern(const IterSumExpr sum_expr) {
IntImm base_scale(nullptr);
size_t base_index = 0;
for (size_t i = 0; i < sum_expr->args.size(); ++i) {
if (const auto* op = sum_expr->args[i]->scale.as<IntImmNode>()) {
if (!base_scale.defined() || op->value < base_scale->value) {
base_scale = GetRef<IntImm>(op);
base_index = i;
}
}
}
ICHECK(base_scale.defined());
std::vector<IterSplitExpr> iters;
std::vector<bool> visited(sum_expr->args.size(), false);
PrimExpr expected_scale = base_scale;
for (size_t i = 0; i < sum_expr->args.size(); i++) {
size_t j = i == 0 ? base_index : 0;
for (; j < sum_expr->args.size(); ++j) {
if (!visited[j] && analyzer_->CanProveEqual(sum_expr->args[j]->scale, expected_scale))
break;
}
ICHECK(j != sum_expr->args.size());
visited[j] = true;
iters.push_back(sum_expr->args[j]);
expected_scale *= sum_expr->args[j]->extent;
/*
* \brief Check the fuse pattern of sum_expr. We assume components of sum_expr is sorted in
* descending order of lower_factor.
*/
void CheckFusePattern(const IterSumExpr sum_expr) {
ICHECK(sum_expr->args.size());
PrimExpr expected_scale = sum_expr->args.back()->scale;
for (size_t i = sum_expr->args.size(); i > 0; i--) {
ICHECK(analyzer_->CanProveEqual(sum_expr->args[i - 1]->scale, expected_scale));
expected_scale *= sum_expr->args[i - 1]->extent;
}
return iters;
}

Analyzer* analyzer_;
Expand Down

0 comments on commit 683c5eb

Please sign in to comment.