Skip to content

Commit

Permalink
[BugFix] Add field is_reduction for SpIterVar (#9)
Browse files Browse the repository at this point in the history
* [BugFix] Add field `is_reduction` for SpIterVar

* Formatting
  • Loading branch information
MasterJH5574 authored and yzh119 committed Feb 15, 2022
1 parent 99f6a35 commit 65ce747
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 34 deletions.
15 changes: 8 additions & 7 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class SparseAxis : public Axis {
*/
class SparseFixedAxisNode : public SparseAxisNode {
public:
Buffer indices;
Buffer indices;
/* fixed number of columns of current sparse axis. */
PrimExpr num_cols;

Expand Down Expand Up @@ -267,7 +267,6 @@ class SparseVariableAxis : public SparseAxis {
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
};


/*!
* \brief Axis Dependency Tree.
*/
Expand Down Expand Up @@ -314,9 +313,7 @@ class SparseBufferNode : public Object {
/* Data type */
runtime::DataType dtype;

inline int ndim() const {
return static_cast<int>(axes.size());
}
inline int ndim() const { return static_cast<int>(axes.size()); }

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &tree);
Expand Down Expand Up @@ -370,24 +367,28 @@ class SpIterVarNode : public Object {
Var var;
PrimExpr max_extent;
SpIterKind kind;
bool is_reduction;
Optional<Axis> axis;

void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("max_extent", &max_extent);
v->Visit("axis", &axis);
v->Visit("is_reduction", &is_reduction);
v->Visit("kind", &kind);
}

bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
equal(axis, other->axis) && equal(kind, other->kind);
equal(axis, other->axis) && equal(is_reduction, other->is_reduction) &&
equal(kind, other->kind);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(var);
hash_reduce(max_extent);
hash_reduce(axis);
hash_reduce(is_reduction);
hash_reduce(kind);
}

Expand All @@ -399,7 +400,7 @@ class SpIterVarNode : public Object {

class SpIterVar : public ObjectRef {
public:
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis = NullOpt);

/*!
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ class SpIterVar(Object):
kind : int
The kind of the SpIterVar
is_reduction : bool
Whether the SpIterVar is a reduction iterator
axis : Optional[Axis]
The axis over which the SpIterVar iterates. Required to be defined
Expand All @@ -222,6 +225,7 @@ class SpIterVar(Object):
var: Var
max_extent: PrimExpr
kind: int
is_reduction: bool
axis: Optional[Axis]

DenseFixed = 0
Expand All @@ -231,6 +235,6 @@ class SpIterVar(Object):

def __init__(self, var, max_extent, kind, axis=None):
self.__init_handle_by_constructor__(
_ffi_api.SpIterVar, var, max_extent, kind, axis # type: ignore
_ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore
)

45 changes: 19 additions & 26 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,12 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {

TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis")
.set_body_typed([](String name, PrimExpr length) {
return DenseFixedAxis(name, length);
});
TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) {
return DenseFixedAxis(name, length);
});

// DenseVariableAxis
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length,
Buffer indptr) {
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
Expand All @@ -61,8 +59,7 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
});

// SparseFixedAxis
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices,
PrimExpr num_cols) {
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
Expand All @@ -74,16 +71,14 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices,
TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indices,
PrimExpr num_cols) {
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
return SparseFixedAxis(name, length, indices, num_cols);
});

// SparseVariableAxis
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length,
Buffer indptr, Buffer indices) {
ObjectPtr<SparseVariableAxisNode> node =
make_object<SparseVariableAxisNode>();
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr,
Buffer indices) {
ObjectPtr<SparseVariableAxisNode> node = make_object<SparseVariableAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
node->indptr = std::move(indptr);
Expand All @@ -94,14 +89,12 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length,
TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indptr,
Buffer indices) {
.set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
return SparseVariableAxis(name, length, indptr, indices);
});

// AxisTree
AxisTree::AxisTree(Array<Axis> axes,
Array<Optional<String>> axis_parent_names) {
AxisTree::AxisTree(Array<Axis> axes, Array<Optional<String>> axis_parent_names) {
CHECK_EQ(axes.size(), axis_parent_names.size())
<< "ValueError: The axes array should have the same length as axis_parent_names "
"array.";
Expand All @@ -121,9 +114,7 @@ AxisTree::AxisTree(Array<Axis> axes,
CHECK(node->axis_map.find(parent_name.value()) != node->axis_map.end())
<< "ValueError: Parent axis name doesn't exist.";
}
Axis parent_axis = (parent_name.get() != nullptr)
? node->axis_map[parent_name.value()]
: root;
Axis parent_axis = (parent_name.get() != nullptr) ? node->axis_map[parent_name.value()] : root;
node->parent[axis] = parent_axis;
if (node->children.find(parent_axis) != node->children.end()) {
node->children[parent_axis].push_back(axis);
Expand All @@ -139,8 +130,7 @@ AxisTree::AxisTree(Array<Axis> axes,
TVM_REGISTER_NODE_TYPE(AxisTreeNode);

TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
.set_body_typed([](Array<Axis> axes,
Array<Optional<String>> axis_parent_names) {
.set_body_typed([](Array<Axis> axes, Array<Optional<String>> axis_parent_names) {
return AxisTree(axes, axis_parent_names);
});

Expand All @@ -164,7 +154,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
});

// SpIterVar
SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis) {
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();

if (kind != SpIterKind::kDenseFixed) {
Expand All @@ -175,15 +166,17 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional
node->var = Var(std::move(name));
node->max_extent = std::move(max_extent);
node->kind = kind;
node->is_reduction = is_reduction;
node->axis = std::move(axis);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(SpIterVarNode);

TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
return SpIterVar(name, max_extent, kind, axis);
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis) {
return SpIterVar(name, max_extent, kind, is_reduction, axis);
});

} // namespace tir
Expand Down

0 comments on commit 65ce747

Please sign in to comment.