Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Axis Dependency Tree aware code-gen and bmm example #28

Merged
merged 12 commits into from
Nov 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class AxisNode : public Object {
/*! \brief Return the `name` field of this axis. */
String GetName() const { return name; }
/*! \brief Return the `length` expression of this axis. */
PrimExpr GetLength() const { return length; }
/*! \brief Return the data type of the `length` expression, used as the index type of this axis. */
DataType GetIndexType() const { return length->dtype; }

/*!
 * \brief Whether this axis is a fixed axis.
 * \return true for DenseFixedAxisNode and SparseFixedAxisNode, false for
 *  DenseVariableAxisNode and SparseVariableAxisNode.
 */
virtual bool is_fixed() const = 0;

static constexpr const char* _type_key = "tir.sparse.Axis";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand Down Expand Up @@ -141,6 +143,10 @@ class DenseFixedAxisNode : public DenseAxisNode {
hash_reduce(from_sparse);
}

/*! \brief A dense fixed axis is always a fixed axis. */
bool is_fixed() const { return true; }

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -177,6 +183,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
hash_reduce(indptr);
}

/*! \brief A dense variable axis is never a fixed axis. */
bool is_fixed() const { return false; }

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -220,6 +230,10 @@ class SparseFixedAxisNode : public SparseAxisNode {
hash_reduce(nnz_cols);
}

/*! \brief A sparse fixed axis is always a fixed axis. */
bool is_fixed() const { return true; }

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
};
Expand Down Expand Up @@ -262,6 +276,10 @@ class SparseVariableAxisNode : public SparseAxisNode {
hash_reduce(indices);
}

/*! \brief A sparse variable axis is never a fixed axis. */
bool is_fixed() const { return false; }

static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
};
Expand All @@ -283,9 +301,9 @@ class SparseVariableAxis : public SparseAxis {
class AxisTreeNode : public Object {
public:
// unordered map that stores the parent relationship between axes.
Map<String, Optional<String>> parent;
Map<String, String> parent;
// unordered map that stores the children relationship between axes.
Map<Optional<String>, Array<String>> children;
Map<String, Array<String>> children;

void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent);
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,10 @@ TVM_DLL Pass ConvertForLoopsToSerial();

/*!
* \brief Lower SparseTIR to TIR.
* \param axis_tree The axis dependency tree.
* \return The pass.
*/
TVM_DLL Pass LowerSparseTIR();
TVM_DLL Pass LowerSparseTIR(AxisTree axis_tree);

} // namespace transform
} // namespace tir
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Optional
from . import _ffi_api
from . import function_pass as _fpass
from ..sparse import AxisTree


def Apply(ftransform):
Expand Down Expand Up @@ -751,12 +752,17 @@ def ConvertForLoopsToSerial():
return _ffi_api.ConvertForLoopsToSerial() # type: ignore


def LowerSparseTIR():
def LowerSparseTIR(axis_tree: AxisTree):
"""Lower SparseTIR to TIR

Parameters
----------
axis_tree : AxisTree
The axis dependency tree.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerSparseTIR() # type: ignore
return _ffi_api.LowerSparseTIR(axis_tree) # type: ignore
9 changes: 6 additions & 3 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,15 @@ AxisTree::AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent
"axis_parent_names "
"array.";
ObjectPtr<AxisTreeNode> node = make_object<AxisTreeNode>();
Map<String, Optional<String>> parent;
Map<Optional<String>, Array<String>> children;
Map<String, String> parent;
Map<String, Array<String>> children;
for (size_t i = 0; i < axis_names.size(); i++) {
// update parent map & children map
String axis_name = axis_names[i];
Optional<String> parent_name = axis_parent_names[i];
String parent_name("root");
if (axis_parent_names[i].defined()) {
parent_name = axis_parent_names[i].value();
}
parent.Set(axis_name, parent_name);

auto it = children.find(parent_name);
Expand Down
149 changes: 118 additions & 31 deletions src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <set>
#include <stack>
#include <utility>

#include "../schedule/analysis.h"
Expand Down Expand Up @@ -87,8 +89,8 @@ Map<Var, Buffer> UpdateBufferMap(PrimFunc f) {
*/
class IndexTransformer : public StmtExprMutator {
public:
explicit IndexTransformer(AccessAndDependencyCollector collector)
: collector_(std::move(collector)) {}
explicit IndexTransformer(AccessAndDependencyCollector collector, AxisTree axis_tree)
: collector_(std::move(collector)), axis_tree_(std::move(axis_tree)) {}

private:
/*!
Expand Down Expand Up @@ -281,43 +283,124 @@ class IndexTransformer : public StmtExprMutator {
sp_block->init.defined() ? VisitStmt(sp_block->init.value()) : Optional<Stmt>(NullOpt);
Stmt body = VisitStmt(sp_block->body);

// Step 2. Create the new outer loop vars.
Array<Var> loop_vars;
// Step 2. Create the new loop vars.
std::unordered_map<const VarNode*, PrimExpr> var_map;
loop_vars.reserve(n_iter);
Array<Var> all_loop_vars;
var_map.reserve(n_iter);
for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) {
Var loop_var("v_" + sp_iter->var->name_hint);
loop_vars.push_back(loop_var);
all_loop_vars.push_back(loop_var);
var_map[sp_iter->var.get()] = loop_var;
}

// Step 3. Create block iters and iter bindings.
Array<IterVar> block_iters;
Array<PrimExpr> iter_bindings;
block_iters.reserve(n_iter);
iter_bindings.reserve(n_iter);
for (int i = 0; i < n_iter; ++i) {
block_iters.push_back(SpIterVarToIterVar(sp_block->sp_iter_vars[i], var_map));
iter_bindings.push_back(loop_vars[i]);
}
// Step 3. Collect block iters and iter bindings.
std::set<String> in_stack;
in_stack.insert("root");
/* A stack that stores block itervars in each block. */
std::stack<Array<IterVar>> block_iters_st;
/* A stack that stores itervar bindings in each block. */
std::stack<Array<PrimExpr>> iter_bindings_st;
/* A stack that stores generated loop vars in each block. */
std::stack<Array<Var>> loop_vars_st;
/* A stack that stores whether to place init block in each block. */
std::stack<bool> place_init_st;
/* An indicator that records whether init block has been set. */
bool init_set = false;
do {
/* Block itervars of current block. */
Array<IterVar> block_iters;
/* Itervar bindings of current block. */
Array<PrimExpr> iter_bindings;
/* Axis names of current block. */
Array<String> axis_names;
/* Generated loop vars of current block. */
Array<Var> loop_vars;
/* An indicator that records whether there is reduction axis in current block. */
bool has_reduction_var = false;
for (int i = 0; i < n_iter; ++i) {
SpIterVar sp_it_var = sp_block->sp_iter_vars[i];
String axis_name = sp_it_var->axis->name;
auto&& parent_axis = axis_tree_->parent.Get(axis_name);
CHECK(parent_axis.defined()) << "Sparse IterVar not defined in Axis Tree.";
String parent_axis_name = parent_axis.value();
bool is_fixed_axis = sp_it_var->axis->is_fixed();
/* Add itervar to current block when
* - it's not used yet (not in stack), and
* - either its parent axis was used in outer blocks, or
*   it's an iterator over a fixed axis.
*/
if ((is_fixed_axis || in_stack.find(parent_axis_name) != in_stack.end()) &&
in_stack.find(axis_name) == in_stack.end()) {
loop_vars.push_back(all_loop_vars[i]);
axis_names.push_back(std::move(axis_name));
block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map));
iter_bindings.push_back(all_loop_vars[i]);
has_reduction_var |= sp_it_var->is_reduction;
}
Comment on lines +327 to +339
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like when a sparse block has three iterators of type df, dv, df, the lowered TIR will have a loop order of df, df, dv?

I think we should break this loop once the if isn't satisfied 🤔. It's not expected to reorder the loops in lowering.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think reordering loops at the Sparse TIR stage is not useful, because all valid reorder schedules can be performed after lowering.

}

/* Tag axes in current block as "in-stack". */
for (const String&& axis_name : axis_names) {
in_stack.insert(std::move(axis_name));
}

/* Update stack. */
if (!block_iters.empty()) {
block_iters_st.push(std::move(block_iters));
iter_bindings_st.push(std::move(iter_bindings));
loop_vars_st.push(std::move(loop_vars));
if (init_set) {
place_init_st.push(false);
} else {
place_init_st.push(has_reduction_var);
init_set |= has_reduction_var;
}
} else {
break;
}
} while (true);

// Step 4. Generate the read-region and write-region of the block.
Array<BufferRegion> reads{nullptr};
Array<BufferRegion> writes{nullptr};
GenerateReadWriteRegions(sp_block, &reads, &writes);

// Step 5. Create the block and block-realize
Map<String, ObjectRef> mapping;
mapping.Set("sparse", Bool(true));
Block block(block_iters, std::move(reads), std::move(writes), sp_block->name, std::move(body),
std::move(init), {}, {}, std::move(mapping));
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));

// Step 6. Create outer loops and the block binding.
Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars);
// Step 5. Generate nested blocks and loops from innermost to outermost.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't got enough time to review yet. I'm going to think more on this part tomorrow.

int blk_counter = 0;
while (!block_iters_st.empty()) {
Array<IterVar> block_iters = std::move(block_iters_st.top());
Array<PrimExpr> iter_bindings = std::move(iter_bindings_st.top());
Array<Var> loop_vars = std::move(loop_vars_st.top());
bool place_init = place_init_st.top();
block_iters_st.pop();
iter_bindings_st.pop();
loop_vars_st.pop();
place_init_st.pop();

Map<String, ObjectRef> mapping;
mapping.Set("sparse", Bool(true));
String blk_name_hint = sp_block->name;
if (blk_counter != 0) {
blk_name_hint = blk_name_hint + "_" + std::to_string(blk_counter);
}
Block block(/*iter_vars=*/block_iters,
/*reads=*/reads,
/*writes=*/writes,
/*name_hint=*/blk_name_hint,
/*body=*/std::move(body),
/*init=*/place_init ? std::move(init) : NullOpt,
/*alloc_buffers=*/{},
/*match_buffers=*/{},
/*annotations=*/std::move(mapping),
/*span=*/sp_block->span);
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));
// Generate outer loop and the block binding.
Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars);
body = loop;
blk_counter += 1;
}

return loop;
return body;
}

/*!
Expand Down Expand Up @@ -380,9 +463,10 @@ class IndexTransformer : public StmtExprMutator {
}

/*!
* \brief generated nested for loops for sparse block.
* \brief generated nested for-loops for sparse block.
* \param block_iters The iterators defined in sparse blocks.
* \param loop_vars The loop variables bound to the block iterators.
* \return The outermost loop.
*/
Stmt GenerateLoops(Stmt body, const Array<IterVar>& block_iters, const Array<Var>& loop_vars) {
int n_iter = static_cast<int>(block_iters.size());
Expand All @@ -394,6 +478,7 @@ class IndexTransformer : public StmtExprMutator {
}

AccessAndDependencyCollector collector_;
AxisTree axis_tree_;
arith::Analyzer ana_;
std::unordered_set<const SparseBufferNode*> buffer_read_;
std::unordered_set<const SparseBufferNode*> buffer_write_;
Expand All @@ -411,11 +496,12 @@ Stmt WrapWithRootBlock(Stmt body) {
}

/*!
* \brief Rewrite the given primitive function
* \brief Rewrite the given primitive function.
* \param axis_tree The axis dependency tree.
* \param f The Sparse-TIR primitive function to lower.
* \return lowered primitive function in TIR.
*/
PrimFunc LowerSparseTIR(PrimFunc f) {
PrimFunc LowerSparseTIR(AxisTree axis_tree, PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
Expand All @@ -425,7 +511,7 @@ PrimFunc LowerSparseTIR(PrimFunc f) {
AccessAndDependencyCollector collector;
collector.Collect(f->body);
// Step 3. Lower indices.
fptr->body = IndexTransformer(collector)(std::move(f->body));
fptr->body = IndexTransformer(collector, axis_tree)(std::move(f->body));
// Step 4. Wrap the function body with a root block.
fptr->body = WrapWithRootBlock(std::move(fptr->body));
return f;
Expand All @@ -438,10 +524,11 @@ namespace transform {

/*!
* \brief The lowering pass from Sparse TIR to TIR.
* \param axis_tree The axis dependency tree.
*/
Pass LowerSparseTIR() {
Pass LowerSparseTIR(AxisTree axis_tree) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LowerSparseTIR(std::move(f));
return LowerSparseTIR(std::move(axis_tree), std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseTIR", {});
}
Expand Down
Loading