Skip to content

Commit

Permalink
add comment
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 19, 2022
1 parent 9ec0974 commit 2909a06
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2105,12 +2105,28 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
ICHECK(block_loops.size() == iter_types_block.size());

// We assume that the orders of iter_vars in the target and the desc block are consistent.
// Based on that assumption, the following logic supports arbitrary permutations of a loop order,
// such as

// for k:
// for i:
// for j:
// C[i, j] += A[i, k] * B[k, j]

// or

// for i:
// for j:
// for k:
// C[i, j] += A[i, k] * B[k, j]

int next_block_ind = block_loops.size() - 1;
for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
// Step 4.2. Find the corresponding loop of the i-th block var of desc
// Step 3.1. Find the corresponding loop of the i_desc-th block var of desc
const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
const tir::ForNode* desc_loop = nullptr;
IterVarType iter_type_desc;
IterVarType iter_type_desc = iter_types_desc[i_desc];
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
// Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
Expand All @@ -2127,29 +2143,32 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,

const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();

const tir::ForNode* block_loop = nullptr;

// Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
PrimExpr block_bind;
for (int i_block = next_block_ind; i_block >= 0; --i_block) {
if (iter_types_block[i_block] == iter_type_desc) {
next_block_ind = i_block - 1;
block_bind = block->iter_values[i_block];
for (int i = next_block_ind; i >= 0; --i) {
if (iter_types_block[i] == iter_type_desc) {
next_block_ind = i - 1;
block_bind = block->iter_values[i];
break;
}
}

if (!block_bind.defined()) return NullOpt;

// Step 3.3. Find the corresponding loop of the target block
for (int i = 0, n = block_loops.size(); i < n; ++i) {
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
if (!UsesVar(r,
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) {
block_loop = block_loops[i];
const IntImmNode* int_block_extent = block_loop->extent.as<IntImmNode>();
const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();

// Check divisibility
if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
return NullOpt;
}

const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loops[i]];
auto it = ret->loop_map.find(block_loop_sref);
if (it == ret->loop_map.end()) {
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
Expand Down

0 comments on commit 2909a06

Please sign in to comment.