Skip to content

Commit

Permalink
[Refactor] Refactor Unittest (#48)
Browse files Browse the repository at this point in the history
* upd

* remove redundancy
  • Loading branch information
yzh119 committed Jan 21, 2022
1 parent 2b542f5 commit 53dda9a
Show file tree
Hide file tree
Showing 10 changed files with 974 additions and 945 deletions.
4 changes: 2 additions & 2 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ class Attach(SpecialStmt):
def __init__(self):
def attach_axis(
parent: Axis,
orig: Axis,
orig: DenseVariableAxis,
nnz: PrimExpr,
indptr_var: tvm.tir.Var,
idtype: str = "int32",
Expand All @@ -967,7 +967,7 @@ def attach_axis(
f"`attach_axis` expected assign to only one var, but got {names}", span
)

indptr_len = orig.nnz + 1
indptr_len = orig.parent.length + 1
indptr_buf = tvm.tir.decl_buffer(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def idtype(self):
def nnz(self):
return _ffi_api.GetNNZ(self)

@property
def parent(self):
return _ffi_api.GetParent(self)


@tvm._ffi.register_object("tir.sparse.DenseAxis")
class DenseAxis(Axis):
Expand Down Expand Up @@ -168,9 +172,9 @@ class AttachedAxis(DenseVariableAxis):
nnz : PrimExpr
indptr : PrimExpr

def __init__(self, name, parent, length, nnz, indptr):
def __init__(self, name, parent, orig, nnz, indptr):
self.__init_handle_by_constructor__(
_ffi_api.AttachedAxis, name, parent, length, nnz, indptr
_ffi_api.AttachedAxis, name, parent, orig, nnz, indptr
)


Expand Down
2 changes: 2 additions & 0 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis)

TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->GetNNZ(); });

TVM_REGISTER_GLOBAL("tir.sparse.GetParent").set_body_typed([](Axis axis) { return axis->GetParentAxis(); });

/******** AxisNode ********/

std::tuple<PrimExpr, PrimExpr> AxisNode::GetOffsetExtent(SparseCtx* ctx) const {
Expand Down
14 changes: 10 additions & 4 deletions src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,15 @@ class SparseBlockCtx : public SparseCtx {
Axis orig = group[j];
SetOffset(orig, offset);
if (j > 0) {
// TODO(zihao): support more than sv axis.
offset = lower_bound(Downcast<SparseVariableAxis>(orig)->indptr->data, offset,
Integer(0), orig->GetNNZ());
Buffer indptr;
if (auto sv_axis = orig.as<SparseVariableAxisNode>()) {
indptr = sv_axis->indptr;
} else if (auto dv_axis = orig.as<DenseVariableAxisNode>()) {
indptr = dv_axis->indptr;
} else {
throw;
}
offset = upper_bound(indptr->data, offset, Integer(0), indptr->shape[0]) - 1;
}
}
for (size_t j = 0; j < group.size(); ++j) {
Expand Down Expand Up @@ -379,7 +385,7 @@ class IndexTransformer : public StmtExprMutator {
*/
IterVar SpIterVarToIterVar(const SpIterVar& sp_iter, Map<Var, PrimExpr> var_map) {
// Substitute the iteration vars in the expression with the loop vars.
return IterVar(Range::FromMinExtent(0, Substitute(sp_blk_ctx_.GetIterExtent(sp_iter), var_map)),
return IterVar(Range::FromMinExtent(0, sp_blk_ctx_.GetIterExtent(sp_iter)),
sp_iter->var, sp_iter->is_reduction ? kCommReduce : kDataPar);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import torch as th
from tvm.script import tir as T
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from lowered_tir import lowered_rgcn_forward
from sparse_tir_scripts import rgcn_forward


class TorchOpTimer(object):
Expand Down Expand Up @@ -63,69 +65,6 @@ def prepare_graph(g, ntype=None):
return g


@T.prim_func
def rgcn(
etype: T.handle,
w: T.handle,
x: T.handle,
y: T.handle,
indptr: T.handle,
indices: T.handle,
n: T.int32,
r: T.int32,
feat_size: T.int32,
nnz: T.int32
):
I = T.dense_fixed(n)
J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
R = T.dense_fixed(r)
F_in = T.dense_fixed(feat_size)
F_out = T.dense_fixed(feat_size)
E = T.match_sparse_buffer(etype, (I, J), "int32")
W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32")
X = T.match_sparse_buffer(x, (T.dense(J), F_in), "float32")
Y = T.match_sparse_buffer(y, (I, F_out), "float32")
T.func_attr({"global_symbol": "main", "tir.noalias": True})
with T.iter([I, F_out, J, F_in], "SSRR", "rgcn-forward") as [
vi, vout, vj, vin,
]:
with T.init():
Y[vi, vout] = 0.
Y[vi, vout] = Y[vi, vout] + W[E[vi, vj], vout, vin] * X[vj, vin]


@T.prim_func
def lowered_rgcn(etype: T.handle, w: T.handle, x: T.handle, y: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, r: T.int32, feat_size: T.int32, nnz: T.int32) -> None:
E_data = T.match_buffer(etype, [nnz], dtype="int32")
W_data = T.match_buffer(w, [r * feat_size * feat_size], dtype="float32")
X_data = T.match_buffer(x, [n * feat_size], dtype="float32")
Y_data = T.match_buffer(y, [n * feat_size], dtype="float32")
J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32")
J_indices = T.match_buffer(indices, [nnz], dtype="int32")
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
for v_vi, v_vout in T.grid(n, feat_size):
with T.block("rgcn-forward_0"):
vi, vout = T.axis.remap("SS", [v_vi, v_vout])
T.reads(J_indptr[0: n + 1], J_indices[0: nnz], E_data[0: nnz], W_data[0: r *
feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * feat_size])
T.writes(Y_data[0: n * feat_size])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr[v_vi + 1] - J_indptr[v_vi]):
for v_vin in T.serial(feat_size):
with T.block("rgcn-forward_1"):
vj, vin = T.axis.remap("RR", [v_vj, v_vin])
T.reads(J_indptr[0: n + 1], J_indices[0: nnz], E_data[0: nnz], W_data[0: r *
feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * feat_size])
T.writes(Y_data[0: n * feat_size])
T.block_attr({"sparse": True})
with T.init():
Y_data[vi * feat_size + vout] = T.float32(0)
Y_data[vi * feat_size + vout] = Y_data[vi * feat_size + vout] + W_data[(
E_data[J_indptr[vi] + vj] * feat_size + vout) * feat_size + vin] * X_data[J_indices[J_indptr[vi] + vj] * feat_size + vin]


def test_rgcn(g: DGLHeteroGraph):
feat_size = 16
g = g.to(0)
Expand Down Expand Up @@ -180,9 +119,9 @@ def msg_func(edges):
print("dgl high-mem:\t\t", accum / (total - cold_start))

# tir
mod = tvm.IRModule.from_expr(rgcn)
mod = tvm.IRModule.from_expr(rgcn_forward)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_rgcn, True)
tvm.ir.assert_structural_equal(mod["main"], lowered_rgcn_forward, True)

N, R, FEAT_SIZE, NNZ = mod["main"].params[-4:]
sch = tir.Schedule(
Expand Down
Loading

0 comments on commit 53dda9a

Please sign in to comment.