Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Add field is_reduction for SpIterVar #9

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class SparseAxis : public Axis {
*/
class SparseFixedAxisNode : public SparseAxisNode {
public:
Buffer indices;
Buffer indices;
/* fixed number of columns of current sparse axis. */
PrimExpr num_cols;

Expand Down Expand Up @@ -267,7 +267,6 @@ class SparseVariableAxis : public SparseAxis {
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
};


/*!
* \brief Axis Dependency Tree.
*/
Expand Down Expand Up @@ -314,9 +313,7 @@ class SparseBufferNode : public Object {
/* Data type */
runtime::DataType dtype;

inline int ndim() const {
return static_cast<int>(axes.size());
}
inline int ndim() const { return static_cast<int>(axes.size()); }

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &tree);
Expand Down Expand Up @@ -370,24 +367,28 @@ class SpIterVarNode : public Object {
Var var;
PrimExpr max_extent;
SpIterKind kind;
bool is_reduction;
Optional<Axis> axis;

void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("max_extent", &max_extent);
v->Visit("axis", &axis);
v->Visit("is_reduction", &is_reduction);
v->Visit("kind", &kind);
}

bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
equal(axis, other->axis) && equal(kind, other->kind);
equal(axis, other->axis) && equal(is_reduction, other->is_reduction) &&
equal(kind, other->kind);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(var);
hash_reduce(max_extent);
hash_reduce(axis);
hash_reduce(is_reduction);
hash_reduce(kind);
}

Expand All @@ -399,7 +400,7 @@ class SpIterVarNode : public Object {

class SpIterVar : public ObjectRef {
public:
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis = NullOpt);

/*!
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ class SpIterVar(Object):

kind : int
The kind of the SpIterVar

is_reduction : bool
Whether the SpIterVar is a reduction iterator

axis : Optional[Axis]
The axis over which the SpIterVar iterates. Required to be defined
Expand All @@ -222,6 +225,7 @@ class SpIterVar(Object):
var: Var
max_extent: PrimExpr
kind: int
is_reduction: bool
axis: Optional[Axis]

DenseFixed = 0
Expand All @@ -231,6 +235,6 @@ class SpIterVar(Object):

def __init__(self, var, max_extent, kind, axis=None):
self.__init_handle_by_constructor__(
_ffi_api.SpIterVar, var, max_extent, kind, axis # type: ignore
_ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore
)

45 changes: 19 additions & 26 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,12 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {

TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis")
.set_body_typed([](String name, PrimExpr length) {
return DenseFixedAxis(name, length);
});
TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) {
return DenseFixedAxis(name, length);
});

// DenseVariableAxis
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length,
Buffer indptr) {
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
Expand All @@ -61,8 +59,7 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
});

// SparseFixedAxis
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices,
PrimExpr num_cols) {
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
Expand All @@ -74,16 +71,14 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices,
TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indices,
PrimExpr num_cols) {
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
return SparseFixedAxis(name, length, indices, num_cols);
});

// SparseVariableAxis
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length,
Buffer indptr, Buffer indices) {
ObjectPtr<SparseVariableAxisNode> node =
make_object<SparseVariableAxisNode>();
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr,
Buffer indices) {
ObjectPtr<SparseVariableAxisNode> node = make_object<SparseVariableAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
node->indptr = std::move(indptr);
Expand All @@ -94,14 +89,12 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length,
TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indptr,
Buffer indices) {
.set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
return SparseVariableAxis(name, length, indptr, indices);
});

// AxisTree
AxisTree::AxisTree(Array<Axis> axes,
Array<Optional<String>> axis_parent_names) {
AxisTree::AxisTree(Array<Axis> axes, Array<Optional<String>> axis_parent_names) {
CHECK_EQ(axes.size(), axis_parent_names.size())
<< "ValueError: The axes array should have the same length as axis_parent_names "
"array.";
Expand All @@ -121,9 +114,7 @@ AxisTree::AxisTree(Array<Axis> axes,
CHECK(node->axis_map.find(parent_name.value()) != node->axis_map.end())
<< "ValueError: Parent axis name doesn't exist.";
}
Axis parent_axis = (parent_name.get() != nullptr)
? node->axis_map[parent_name.value()]
: root;
Axis parent_axis = (parent_name.get() != nullptr) ? node->axis_map[parent_name.value()] : root;
node->parent[axis] = parent_axis;
if (node->children.find(parent_axis) != node->children.end()) {
node->children[parent_axis].push_back(axis);
Expand All @@ -139,8 +130,7 @@ AxisTree::AxisTree(Array<Axis> axes,
TVM_REGISTER_NODE_TYPE(AxisTreeNode);

TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
.set_body_typed([](Array<Axis> axes,
Array<Optional<String>> axis_parent_names) {
.set_body_typed([](Array<Axis> axes, Array<Optional<String>> axis_parent_names) {
return AxisTree(axes, axis_parent_names);
});

Expand All @@ -164,7 +154,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
});

// SpIterVar
SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis) {
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();

if (kind != SpIterKind::kDenseFixed) {
Expand All @@ -175,15 +166,17 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional
node->var = Var(std::move(name));
node->max_extent = std::move(max_extent);
node->kind = kind;
node->is_reduction = is_reduction;
node->axis = std::move(axis);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(SpIterVarNode);

TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
return SpIterVar(name, max_extent, kind, axis);
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis) {
return SpIterVar(name, max_extent, kind, is_reduction, axis);
});

} // namespace tir
Expand Down