Bugfix and more tests for axis fusion, new workload #50

Merged
merged 2 commits into from
Jan 24, 2022
2 changes: 1 addition & 1 deletion python/tvm/script/tir/special_stmt.py
@@ -935,7 +935,7 @@ def dense_variable(
         )

         length, nnz = shape
-        indptr_len = parent_axis.length + 1
+        indptr_len = parent_axis.nnz + 1
         indptr_buf = tvm.tir.decl_buffer(
             (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
         )
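
The one-line fix above matters when the parent of a dense_variable axis is itself variable: the indptr array needs one slot per element of the parent (its nnz) plus one, not one slot per unit of the parent's padded length. A minimal NumPy sketch of the sizing rule (illustration only, not SparseTIR API):

import numpy as np

n = 2                              # dense fixed axis I
indptr_j = np.array([0, 2, 3])     # J variable under I: len(indptr_j) == n + 1
nnz_j = int(indptr_j[-1])          # 3 elements along J in total

# K variable under J: one offset per element of J, plus one.
indptr_k = np.array([0, 1, 4, 6])  # len(indptr_k) == nnz_j + 1, not length + 1
assert len(indptr_k) == nnz_j + 1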
6 changes: 5 additions & 1 deletion src/tir/transforms/lower_sparse_tir.cc
@@ -529,7 +529,7 @@ class IndexTransformer : public StmtExprMutator {

     const Optional<Var>& loop_var = axis2loop_var.Get(axis->GetParentAxis().value());
     CHECK(loop_var.defined()) << "ValueError: The parent axis of " << axis
-                              << "does not appear in the sparse block";
+                              << " does not appear in the sparse block";

     if (LoopVarAppears(loop_var.value())) {
       return true;
@@ -559,6 +559,10 @@ class IndexTransformer : public StmtExprMutator {
     for (const SpIterVar& sp_iter_var : sp_block->sp_iter_vars) {
       Var loop_var("v_" + sp_iter_var->var->name_hint);
       var_map.Set(sp_iter_var->var, loop_var);
+      if (auto fused_axis = sp_iter_var->axis.as<FusedAxisNode>()) {
+        // handle the special case of a fused axis
+        axis2loop_var.Set(fused_axis->group[fused_axis->index], loop_var);
+      }
       axis2loop_var.Set(sp_iter_var->axis, loop_var);
     }
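
The added branch fixes loop-variable lookup under fusion: when a sparse iterator's axis is a FusedAxis wrapper, later queries such as GetParentAxis() return the raw member axis, so the loop variable must be registered under that member axis as well, not only under the wrapper. A runnable Python sketch of the bookkeeping (hypothetical class names, not the C++ API):

from dataclasses import dataclass

@dataclass(frozen=True)
class Axis:
    name: str
    parent: "Axis | None" = None

@dataclass(frozen=True)
class FusedAxis:
    group: tuple   # the axes fused together
    index: int     # which member this iterator covers

axis2loop_var = {}

def bind(sp_axis, loop_var):
    if isinstance(sp_axis, FusedAxis):
        # the special case added in this PR: also register the member axis,
        # so parent-axis lookups on raw axes can find the loop variable
        axis2loop_var[sp_axis.group[sp_axis.index]] = loop_var
    axis2loop_var[sp_axis] = loop_var

I_ = Axis("I")
J_ = Axis("J", parent=I_)
bind(FusedAxis(group=(I_, J_), index=0), "v_i")  # fused iterator covering I
assert axis2loop_var[J_.parent] == "v_i"         # lookup on raw axis succeeds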
54 changes: 54 additions & 0 deletions tests/python/sparsetir/bench_rgcn_new.py
@@ -0,0 +1,54 @@
from dgl.heterograph import DGLHeteroGraph
import tvm
import tvm.testing
import tvm.tir as tir
import scipy.sparse as sp
import numpy as np
import dgl
import dgl.function as fn
import torch as th
from tvm.script import tir as T
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset


@T.prim_func
def rgcn_hetero_forward(
    offset_ntype: T.handle,
    w: T.handle,
    x: T.handle,
    y: T.handle,
    indptr_i: T.handle,
    indptr_j: T.handle,
    indices_j: T.handle,
    n: T.int32,
    r: T.int32,
    feat_size: T.int32,
    nnz_i: T.int32,
    nnz_j: T.int32
):
    I_flatten = T.dense_fixed(n)
    R = T.dense_fixed(r)
    I = T.dense_variable(R, (n, nnz_i), indptr_i, "int32")
    J = T.sparse_variable(I, (n, nnz_j), (indptr_j, indices_j), "int32")
    F_in = T.dense_fixed(feat_size)
    F_out = T.dense_fixed(feat_size)
    offset = T.match_sparse_buffer(offset_ntype, (R,), "int32")
    W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32")
    X = T.match_sparse_buffer(x, (I_flatten, F_in), "float32")
    Y = T.match_sparse_buffer(y, (I_flatten, R, F_out), "float32")
    with T.iter([T.fuse(R, I), F_out, J, F_in], "SSSRR", "rgcn-hetero-forward") as [
        vr, vi, vout, vj, vin
    ]:
        with T.init():
            Y[offset[vr] + vi, vr, vout] = 0.
        Y[offset[vr] + vi, vr, vout] = Y[offset[vr] + vi, vr, vout] + W[vr, vout, vin] * X[vj, vin]


def test_lower_rgcn_hetero():
    mod = tvm.IRModule.from_expr(rgcn_hetero_forward)
    mod = tvm.tir.transform.LowerSparseTIR()(mod)
    print(mod["main"].script())


if __name__ == "__main__":
    test_lower_rgcn_hetero()
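
As a sanity check on the new workload's semantics, here is a hedged NumPy reference for what rgcn_hetero_forward computes, under my reading of the axis layout (rows of relation r occupy a contiguous slice starting at offset[r]; indptr_j is indexed by the flattened position of (r, i)):

import numpy as np

def rgcn_hetero_ref(offset, W, X, indptr_i, indptr_j, indices_j):
    num_rels, f_out, f_in = W.shape
    Y = np.zeros((X.shape[0], num_rels, f_out), dtype=X.dtype)
    for vr in range(num_rels):
        for vi in range(indptr_i[vr + 1] - indptr_i[vr]):
            row = indptr_i[vr] + vi          # flattened position of (vr, vi)
            for e in range(indptr_j[row], indptr_j[row + 1]):
                vj = indices_j[e]            # source node of the edge
                Y[offset[vr] + vi, vr, :] += W[vr] @ X[vj]
    return Y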
68 changes: 68 additions & 0 deletions tests/python/sparsetir/lowered_tir.py
@@ -394,3 +394,71 @@ def lowered_rgcn_forward(etype: T.handle, w: T.handle, x: T.handle, y: T.handle,
                        Y_data[vi * feat_size + vout] = T.float32(0)
                    Y_data[vi * feat_size + vout] = Y_data[vi * feat_size + vout] + W_data[(
                        E_data[J_indptr[vi] + vj] * feat_size + vout) * feat_size + vin] * X_data[J_indices[J_indptr[vi] + vj] * feat_size + vin]


@T.prim_func
def lowered_fused_reduction_4d_2d(x: T.handle, y: T.handle, indptr_j: T.handle, indptr_k: T.handle, indptr_l: T.handle, n: T.int32, nnz_j: T.int32, nnz_k: T.int32, nnz_l: T.int32) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    X_data = T.match_buffer(x, [nnz_l], dtype="float32")
    Y_data = T.match_buffer(y, [nnz_j], dtype="float32")
    J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32")
    K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32")
    L_indptr = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32")
    # body
    # with T.block("root")
    for v_vi, v_vj in T.grid(1, nnz_j):
        with T.block("reduction_4d_2d0"):
            vi, vj = T.axis.remap("SS", [v_vi, v_vj])
            T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
                    L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j])
            T.writes(Y_data[0: nnz_j])
            T.block_attr({"sparse": True})
            for v_vk in T.serial(K_indptr[vj + 1] - K_indptr[vj]):
                with T.block("reduction_4d_2d1"):
                    vk = T.axis.reduce(K_indptr[vj + 1] - K_indptr[vj], v_vk)
                    T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
                            L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j])
                    T.writes(Y_data[0: nnz_j])
                    T.block_attr({"sparse": True})
                    with T.init():
                        Y_data[vj] = T.float32(0)
                    for v_vl in T.serial(L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk]):
                        with T.block("reduction_4d_2d2"):
                            vl = T.axis.reduce(
                                L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk], v_vl)
                            T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
                                    L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j])
                            T.writes(Y_data[0: nnz_j])
                            T.block_attr({"sparse": True})
                            Y_data[vj] = Y_data[vj] + X_data[L_indptr[K_indptr[vj] + vk] + vl]


@T.prim_func
def lowered_fused_reduction_4d_3d(x: T.handle, y: T.handle, indptr_j: T.handle, indptr_k: T.handle, indptr_l: T.handle, n: T.int32, nnz_j: T.int32, nnz_k: T.int32, nnz_l: T.int32) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    X_data = T.match_buffer(x, [nnz_l], dtype="float32")
    Y_data = T.match_buffer(y, [nnz_k], dtype="float32")
    J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32")
    K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32")
    L_indptr = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32")
    # body
    # with T.block("root")
    for v_vi, v_vj, v_vk in T.grid(1, 1, nnz_k):
        with T.block("reduction_4d_3d0"):
            vi, vj, vk = T.axis.remap("SSS", [v_vi, v_vj, v_vk])
            T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
                    L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_k])
            T.writes(Y_data[0: nnz_k])
            T.block_attr({"sparse": True})
            for v_vl in T.serial(L_indptr[vk + 1] - L_indptr[vk]):
                with T.block("reduction_4d_3d1"):
                    vl = T.axis.reduce(L_indptr[vk + 1] - L_indptr[vk], v_vl)
                    T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
                            L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_k])
                    T.writes(Y_data[0: nnz_k])
                    T.block_attr({"sparse": True})
                    with T.init():
                        Y_data[vk] = T.float32(0)
                    Y_data[vk] = Y_data[vk] + X_data[L_indptr[vk] + vl]
48 changes: 48 additions & 0 deletions tests/python/sparsetir/sparse_tir_scripts.py
@@ -282,6 +282,54 @@ def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.
        B[vi] = B[vi] + A[vi, vj, vk]


@T.prim_func
def fused_reduction_4d_2d(
    x: T.handle,
    y: T.handle,
    indptr_j: T.handle,
    indptr_k: T.handle,
    indptr_l: T.handle,
    n: T.int32,
    nnz_j: T.int32,
    nnz_k: T.int32,
    nnz_l: T.int32) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    I = T.dense_fixed(n)
    J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32")
    K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32")
    L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32")
    X = T.match_sparse_buffer(x, (I, J, K, L), "float32")
    Y = T.match_sparse_buffer(y, (I, J), "float32")
    with T.iter([T.fuse(I, J), K, L], "SSRR", "reduction_4d_2d") as [vi, vj, vk, vl]:
        with T.init():
            Y[vi, vj] = 0.0
        Y[vi, vj] = Y[vi, vj] + X[vi, vj, vk, vl]


@T.prim_func
def fused_reduction_4d_3d(
    x: T.handle,
    y: T.handle,
    indptr_j: T.handle,
    indptr_k: T.handle,
    indptr_l: T.handle,
    n: T.int32,
    nnz_j: T.int32,
    nnz_k: T.int32,
    nnz_l: T.int32) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    I = T.dense_fixed(n)
    J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32")
    K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32")
    L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32")
    X = T.match_sparse_buffer(x, (I, J, K, L), "float32")
    Y = T.match_sparse_buffer(y, (I, J, K), "float32")
    with T.iter([T.fuse(I, J, K), L], "SSSR", "reduction_4d_3d") as [vi, vj, vk, vl]:
        with T.init():
            Y[vi, vj, vk] = 0.0
        Y[vi, vj, vk] = Y[vi, vj, vk] + X[vi, vj, vk, vl]
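
For intuition, a hedged NumPy equivalent of fused_reduction_4d_2d (my reading of the ragged layout, not part of the PR): Y[i, j] sums X over its two inner variable axes, and because T.fuse(I, J) flattens the outer axes, the lowered code can walk the nnz_j positions of J directly instead of nesting a loop over I:

import numpy as np

def fused_reduction_4d_2d_ref(x_data, indptr_j, indptr_k, indptr_l):
    nnz_j = int(indptr_j[-1])
    y = np.zeros(nnz_j, dtype=x_data.dtype)
    for jj in range(nnz_j):                  # fused (I, J) loop over nnz_j
        for kk in range(indptr_k[jj], indptr_k[jj + 1]):
            # sum the ragged L segment belonging to element kk of K
            y[jj] += x_data[indptr_l[kk]:indptr_l[kk + 1]].sum()
    return y

The 4d_3d variant is analogous: T.fuse(I, J, K) collapses three levels, so the lowered loop runs over nnz_k and reduces only the innermost L segments.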


@T.prim_func
def rgcn_forward(
    etype: T.handle,
14 changes: 11 additions & 3 deletions tests/python/sparsetir/test_tir_sparse_lower.py
@@ -63,21 +63,18 @@ def test_csr_element_wise():
     tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True)


-@pytest.mark.skip(reason="Under implementation")
 def test_bmm():
     mod = tvm.IRModule.from_expr(bmm)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
     tvm.ir.assert_structural_equal(mod["main"], lowered_bmm)


-@pytest.mark.skip(reason="Under implementation")
 def test_sddmm():
     mod = tvm.IRModule.from_expr(sddmm)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
     tvm.ir.assert_structural_equal(mod["main"], lowered_sddmm)


-@pytest.mark.skip(reason="Under implementation")
 def test_fused_sddmm():
     mod = tvm.IRModule.from_expr(fused_sddmm)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
@@ -96,6 +93,16 @@ def test_square_sum_two_K():
     tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True)


+def test_fused_reduction():
+    mod = tvm.IRModule.from_expr(fused_reduction_4d_2d)
+    mod = tvm.tir.transform.LowerSparseTIR()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_2d, True)
+
+    mod = tvm.IRModule.from_expr(fused_reduction_4d_3d)
+    mod = tvm.tir.transform.LowerSparseTIR()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_3d, True)
+
+
 if __name__ == "__main__":
     test_csrmm()
     test_csrmm_dense_iter()
@@ -109,3 +116,4 @@ def test_square_sum_two_K():
     test_bmm()
     test_square_sum()
     test_square_sum_two_K()
+    test_fused_reduction()
1 change: 1 addition & 0 deletions tests/python/sparsetir/test_tir_sparse_tensorize.py
@@ -0,0 +1 @@
# TODO