forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Axis Dependency Tree aware code-gen and bmm example #28
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
32ab233
upd
yzh119 c1f9f31
upd
yzh119 0ed63b0
upd
yzh119 11e91f3
upd
yzh119 b8fce54
upd
yzh119 9295e0e
upd
yzh119 d865a35
upd
yzh119 111e7eb
upd
yzh119 24e97cd
remove redundancy
yzh119 d8874ce
fix
yzh119 3ffa7b0
upd
yzh119 982636a
upd
yzh119 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,8 @@ | |
#include <tvm/tir/stmt_functor.h> | ||
#include <tvm/tir/transform.h> | ||
|
||
#include <set> | ||
#include <stack> | ||
#include <utility> | ||
|
||
#include "../schedule/analysis.h" | ||
|
@@ -87,8 +89,8 @@ Map<Var, Buffer> UpdateBufferMap(PrimFunc f) { | |
*/ | ||
class IndexTransformer : public StmtExprMutator { | ||
public: | ||
explicit IndexTransformer(AccessAndDependencyCollector collector) | ||
: collector_(std::move(collector)) {} | ||
explicit IndexTransformer(AccessAndDependencyCollector collector, AxisTree axis_tree) | ||
: collector_(std::move(collector)), axis_tree_(std::move(axis_tree)) {} | ||
|
||
private: | ||
/*! | ||
|
@@ -281,43 +283,124 @@ class IndexTransformer : public StmtExprMutator { | |
sp_block->init.defined() ? VisitStmt(sp_block->init.value()) : Optional<Stmt>(NullOpt); | ||
Stmt body = VisitStmt(sp_block->body); | ||
|
||
// Step 2. Create the new outer loop vars. | ||
Array<Var> loop_vars; | ||
// Step 2. Create the new loop vars. | ||
std::unordered_map<const VarNode*, PrimExpr> var_map; | ||
loop_vars.reserve(n_iter); | ||
Array<Var> all_loop_vars; | ||
var_map.reserve(n_iter); | ||
for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) { | ||
Var loop_var("v_" + sp_iter->var->name_hint); | ||
loop_vars.push_back(loop_var); | ||
all_loop_vars.push_back(loop_var); | ||
var_map[sp_iter->var.get()] = loop_var; | ||
} | ||
|
||
// Step 3. Create block iters and iter bindings. | ||
Array<IterVar> block_iters; | ||
Array<PrimExpr> iter_bindings; | ||
block_iters.reserve(n_iter); | ||
iter_bindings.reserve(n_iter); | ||
for (int i = 0; i < n_iter; ++i) { | ||
block_iters.push_back(SpIterVarToIterVar(sp_block->sp_iter_vars[i], var_map)); | ||
iter_bindings.push_back(loop_vars[i]); | ||
} | ||
// Step 3. Collect block iters and iter bindings. | |
std::set<String> in_stack; | ||
in_stack.insert("root"); | ||
/* A stack that stores block itervars in each block. */ | ||
std::stack<Array<IterVar>> block_iters_st; | ||
/* A stack that stores itervar bindings in each block. */ | ||
std::stack<Array<PrimExpr>> iter_bindings_st; | ||
/* A stack that stores generated loop vars in each block. */ | ||
std::stack<Array<Var>> loop_vars_st; | ||
/* A stack that stores whether to place init block in each block. */ | ||
std::stack<bool> place_init_st; | ||
/* An indicator that records whether init block has been set. */ | ||
bool init_set = false; | ||
do { | ||
/* Block itervars of current block. */ | ||
Array<IterVar> block_iters; | ||
/* Itervar bindings of current block. */ | ||
Array<PrimExpr> iter_bindings; | ||
/* Axis names of current block. */ | ||
Array<String> axis_names; | ||
/* Generated loop vars of current block. */ | ||
Array<Var> loop_vars; | ||
/* An indicator that records whether there is reduction axis in current block. */ | ||
bool has_reduction_var = false; | ||
for (int i = 0; i < n_iter; ++i) { | ||
SpIterVar sp_it_var = sp_block->sp_iter_vars[i]; | ||
String axis_name = sp_it_var->axis->name; | ||
auto&& parent_axis = axis_tree_->parent.Get(axis_name); | ||
CHECK(parent_axis.defined()) << "Sparse IterVar not defined in Axis Tree."; | ||
String parent_axis_name = parent_axis.value(); | ||
bool is_fixed_axis = sp_it_var->axis->is_fixed(); | ||
/* Add itervar to current block when | ||
* - it's not used yet (not in stack) and | ||
* - its parent axis was used in outer blocks or | |
* - it's an iterator to a fixed axis. | ||
*/ | ||
if ((is_fixed_axis || in_stack.find(parent_axis_name) != in_stack.end()) && | ||
in_stack.find(axis_name) == in_stack.end()) { | ||
loop_vars.push_back(all_loop_vars[i]); | ||
axis_names.push_back(std::move(axis_name)); | ||
block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map)); | ||
iter_bindings.push_back(all_loop_vars[i]); | ||
has_reduction_var |= sp_it_var->is_reduction; | ||
} | ||
} | ||
|
||
/* Tag axes in current block as "in-stack". */ | ||
for (const String&& axis_name : axis_names) { | ||
in_stack.insert(std::move(axis_name)); | ||
} | ||
|
||
/* Update stack. */ | ||
if (!block_iters.empty()) { | ||
block_iters_st.push(std::move(block_iters)); | ||
iter_bindings_st.push(std::move(iter_bindings)); | ||
loop_vars_st.push(std::move(loop_vars)); | ||
if (init_set) { | ||
place_init_st.push(false); | ||
} else { | ||
place_init_st.push(has_reduction_var); | ||
init_set |= has_reduction_var; | ||
} | ||
} else { | ||
break; | ||
} | ||
} while (true); | ||
|
||
// Step 4. Generate the read-region and write-region of the block. | |
Array<BufferRegion> reads{nullptr}; | ||
Array<BufferRegion> writes{nullptr}; | ||
GenerateReadWriteRegions(sp_block, &reads, &writes); | ||
|
||
// Step 5. Create the block and block-realize | ||
Map<String, ObjectRef> mapping; | ||
mapping.Set("sparse", Bool(true)); | ||
Block block(block_iters, std::move(reads), std::move(writes), sp_block->name, std::move(body), | ||
std::move(init), {}, {}, std::move(mapping)); | ||
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block)); | ||
|
||
// Step 6. Create outer loops and the block binding. | ||
Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars); | ||
// Step 5. Generate nested blocks and loops from innermost to outermost. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I haven't had enough time to review this yet. I'm going to think more about this part tomorrow. |
||
int blk_counter = 0; | ||
while (!block_iters_st.empty()) { | ||
Array<IterVar> block_iters = std::move(block_iters_st.top()); | ||
Array<PrimExpr> iter_bindings = std::move(iter_bindings_st.top()); | ||
Array<Var> loop_vars = std::move(loop_vars_st.top()); | ||
bool place_init = place_init_st.top(); | ||
block_iters_st.pop(); | ||
iter_bindings_st.pop(); | ||
loop_vars_st.pop(); | ||
place_init_st.pop(); | ||
|
||
Map<String, ObjectRef> mapping; | ||
mapping.Set("sparse", Bool(true)); | ||
String blk_name_hint = sp_block->name; | ||
if (blk_counter != 0) { | ||
blk_name_hint = blk_name_hint + "_" + std::to_string(blk_counter); | ||
} | ||
Block block(/*iter_vars=*/block_iters, | ||
/*reads=*/reads, | ||
/*writes=*/writes, | ||
/*name_hint=*/blk_name_hint, | ||
/*body=*/std::move(body), | ||
/*init=*/place_init ? std::move(init) : NullOpt, | ||
/*alloc_buffers=*/{}, | ||
/*match_buffers=*/{}, | ||
/*annotations=*/std::move(mapping), | ||
/*span=*/sp_block->span); | ||
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block)); | ||
// Generate outer loop and the block binding. | ||
Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars); | ||
body = loop; | ||
blk_counter += 1; | ||
} | ||
|
||
return loop; | ||
return body; | ||
} | ||
|
||
/*! | ||
|
@@ -380,9 +463,10 @@ class IndexTransformer : public StmtExprMutator { | |
} | ||
|
||
/*! | ||
* \brief generated nested for loops for sparse block. | ||
* \brief Generate nested for-loops for sparse block. | |
* \param block_iters The iterators defined in sparse blocks. | ||
* \param loop_vars The loop variables bound to block iterators. | |
* \return The outermost loop. | ||
*/ | ||
Stmt GenerateLoops(Stmt body, const Array<IterVar>& block_iters, const Array<Var>& loop_vars) { | ||
int n_iter = static_cast<int>(block_iters.size()); | ||
|
@@ -394,6 +478,7 @@ class IndexTransformer : public StmtExprMutator { | |
} | ||
|
||
AccessAndDependencyCollector collector_; | ||
AxisTree axis_tree_; | ||
arith::Analyzer ana_; | ||
std::unordered_set<const SparseBufferNode*> buffer_read_; | ||
std::unordered_set<const SparseBufferNode*> buffer_write_; | ||
|
@@ -411,11 +496,12 @@ Stmt WrapWithRootBlock(Stmt body) { | |
} | ||
|
||
/*! | ||
* \brief Rewrite the given primitive function | ||
* \brief Rewrite the given primitive function. | ||
* \param axis_tree The axis dependency tree. | ||
* \param f The Sparse-TIR primitive function to lower. | ||
* \return lowered primitive function in TIR. | ||
*/ | ||
PrimFunc LowerSparseTIR(PrimFunc f) { | ||
PrimFunc LowerSparseTIR(AxisTree axis_tree, PrimFunc f) { | ||
// Only apply this pass to TIR that is not from TE schedules | ||
if (!IsFromLegacyTESchedule(f)) { | ||
PrimFuncNode* fptr = f.CopyOnWrite(); | ||
|
@@ -425,7 +511,7 @@ PrimFunc LowerSparseTIR(PrimFunc f) { | |
AccessAndDependencyCollector collector; | ||
collector.Collect(f->body); | ||
// Step 3. Lower indices. | ||
fptr->body = IndexTransformer(collector)(std::move(f->body)); | ||
fptr->body = IndexTransformer(collector, axis_tree)(std::move(f->body)); | ||
// Step 4. Wrap the function body with a root block. | ||
fptr->body = WrapWithRootBlock(std::move(fptr->body)); | ||
return f; | ||
|
@@ -438,10 +524,11 @@ namespace transform { | |
|
||
/*! | ||
* \brief The lowering pass from Sparse TIR to TIR. | |
* \param axis_tree The axis dependency tree. | ||
*/ | ||
Pass LowerSparseTIR() { | ||
Pass LowerSparseTIR(AxisTree axis_tree) { | ||
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { | ||
return LowerSparseTIR(std::move(f)); | ||
return LowerSparseTIR(std::move(axis_tree), std::move(f)); | ||
}; | ||
return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseTIR", {}); | ||
} | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like when a sparse block has three iterators of type `df, dv, df`, the lowered TIR will have a loop order of `df, df, dv`? I think we should break this loop once the `if` isn't satisfied 🤔. Lowering is not expected to reorder the loops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I think reordering loops at the Sparse TIR stage is not useful, because all valid reorder schedules can be performed after lowering.