diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 59938d0af7a3..e0caa2941d90 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -935,7 +935,7 @@ def dense_variable( ) length, nnz = shape - indptr_len = parent_axis.length + 1 + indptr_len = parent_axis.nnz + 1 indptr_buf = tvm.tir.decl_buffer( (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span ) diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 7338a59066ee..84525a97745a 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -529,7 +529,7 @@ class IndexTransformer : public StmtExprMutator { const Optional& loop_var = axis2loop_var.Get(axis->GetParentAxis().value()); CHECK(loop_var.defined()) << "ValueError: The parent axis of " << axis - << "does not appear in the sparse block"; + << " does not appear in the sparse block"; if (LoopVarAppears(loop_var.value())) { return true; @@ -559,6 +559,10 @@ class IndexTransformer : public StmtExprMutator { for (const SpIterVar& sp_iter_var : sp_block->sp_iter_vars) { Var loop_var("v_" + sp_iter_var->var->name_hint); var_map.Set(sp_iter_var->var, loop_var); + if (auto fused_axis = sp_iter_var->axis.as()) { + // handle the special case of fused_axis + axis2loop_var.Set(fused_axis->group[fused_axis->index], loop_var); + } axis2loop_var.Set(sp_iter_var->axis, loop_var); } diff --git a/tests/python/sparsetir/bench_rgcn_new.py b/tests/python/sparsetir/bench_rgcn_new.py new file mode 100644 index 000000000000..b46c3a381d33 --- /dev/null +++ b/tests/python/sparsetir/bench_rgcn_new.py @@ -0,0 +1,54 @@ +from dgl.heterograph import DGLHeteroGraph +import tvm +import tvm.testing +import tvm.tir as tir +import scipy.sparse as sp +import numpy as np +import dgl +import dgl.function as fn +import torch as th +from tvm.script import tir as T +from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset + + +@T.prim_func +def rgcn_hetero_forward( + offset_ntype: T.handle, + w: T.handle, + x: T.handle, + y: T.handle, + indptr_i: T.handle, + indptr_j: T.handle, + indices_j: T.handle, + n: T.int32, + r: T.int32, + feat_size: T.int32, + nnz_i: T.int32, + nnz_j: T.int32 +): + I_flatten = T.dense_fixed(n) + R = T.dense_fixed(r) + I = T.dense_variable(R, (n, nnz_i), indptr_i, "int32") + J = T.sparse_variable(I, (n, nnz_j), (indptr_j, indices_j), "int32") + F_in = T.dense_fixed(feat_size) + F_out = T.dense_fixed(feat_size) + offset = T.match_sparse_buffer(offset_ntype, (R,), "int32") + W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32") + X = T.match_sparse_buffer(x, (I_flatten, F_in), "float32") + Y = T.match_sparse_buffer(y, (I_flatten, R, F_out), "float32") + with T.iter([T.fuse(R, I), F_out, J, F_in], "SSSRR", "rgcn-hetero-forward") as [ + vr, vi, vout, vj, vin + ]: + with T.init(): + Y[offset[vr] + vi, vr, vout] = 0. + Y[offset[vr] + vi, vr, vout] = Y[offset[vr] + vi, vr, vout] + W[vr, vout, vin] * X[vj, vin] + + +def test_lower_rgcn_hetero(): + mod = tvm.IRModule.from_expr(rgcn_hetero_forward) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + print(mod["main"].script()) + + +if __name__ == "__main__": + test_lower_rgcn_hetero() diff --git a/tests/python/sparsetir/lowered_tir.py b/tests/python/sparsetir/lowered_tir.py index 360a709efa5c..31c076e4bb29 100644 --- a/tests/python/sparsetir/lowered_tir.py +++ b/tests/python/sparsetir/lowered_tir.py @@ -394,3 +394,71 @@ def lowered_rgcn_forward(etype: T.handle, w: T.handle, x: T.handle, y: T.handle, Y_data[vi * feat_size + vout] = T.float32(0) Y_data[vi * feat_size + vout] = Y_data[vi * feat_size + vout] + W_data[( E_data[J_indptr[vi] + vj] * feat_size + vout) * feat_size + vin] * X_data[J_indices[J_indptr[vi] + vj] * feat_size + vin] + + +@T.prim_func +def lowered_fused_reduction_4d_2d(x: T.handle, y: T.handle, indptr_j: T.handle, indptr_k: T.handle, indptr_l: T.handle, n: T.int32, nnz_j: T.int32, nnz_k: T.int32, nnz_l: T.int32) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + X_data = T.match_buffer(x, [nnz_l], dtype="float32") + Y_data = T.match_buffer(y, [nnz_j], dtype="float32") + J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32") + K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") + L_indptr = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32") + # body + # with T.block("root") + for v_vi, v_vj in T.grid(1, nnz_j): + with T.block("reduction_4d_2d0"): + vi, vj = T.axis.remap("SS", [v_vi, v_vj]) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j]) + T.writes(Y_data[0: nnz_j]) + T.block_attr({"sparse": True}) + for v_vk in T.serial(K_indptr[vj + 1] - K_indptr[vj]): + with T.block("reduction_4d_2d1"): + vk = T.axis.reduce(K_indptr[vj + 1] - K_indptr[vj], v_vk) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j]) + T.writes(Y_data[0: nnz_j]) + T.block_attr({"sparse": True}) + with T.init(): + Y_data[vj] = T.float32(0) + for v_vl in T.serial(L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk]): + with T.block("reduction_4d_2d2"): + vl = T.axis.reduce( + L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk], v_vl) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j]) + T.writes(Y_data[0: nnz_j]) + T.block_attr({"sparse": True}) + Y_data[vj] = Y_data[vj] + X_data[L_indptr[K_indptr[vj] + vk] + vl] + + +@T.prim_func +def lowered_fused_reduction_4d_3d(x: T.handle, y: T.handle, indptr_j: T.handle, indptr_k: T.handle, indptr_l: T.handle, n: T.int32, nnz_j: T.int32, nnz_k: T.int32, nnz_l: T.int32) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + X_data = T.match_buffer(x, [nnz_l], dtype="float32") + Y_data = T.match_buffer(y, [nnz_k], dtype="float32") + J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32") + K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") + L_indptr = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32") + # body + # with T.block("root") + for v_vi, v_vj, v_vk in T.grid(1, 1, nnz_k): + with T.block("reduction_4d_3d0"): + vi, vj, vk = T.axis.remap("SSS", [v_vi, v_vj, v_vk]) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_k]) + T.writes(Y_data[0: nnz_k]) + T.block_attr({"sparse": True}) + for v_vl in T.serial(L_indptr[vk + 1] - L_indptr[vk]): + with T.block("reduction_4d_3d1"): + vl = T.axis.reduce(L_indptr[vk + 1] - L_indptr[vk], v_vl) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_k]) + T.writes(Y_data[0: nnz_k]) + T.block_attr({"sparse": True}) + with T.init(): + Y_data[vk] = T.float32(0) + Y_data[vk] = Y_data[vk] + X_data[L_indptr[vk] + vl] diff --git a/tests/python/sparsetir/sparse_tir_scripts.py b/tests/python/sparsetir/sparse_tir_scripts.py index 778186b48942..1c473026bc5d 100644 --- a/tests/python/sparsetir/sparse_tir_scripts.py +++ b/tests/python/sparsetir/sparse_tir_scripts.py @@ -282,6 +282,54 @@ def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T. B[vi] = B[vi] + A[vi, vj, vk] +@T.prim_func +def fused_reduction_4d_2d( + x: T.handle, + y: T.handle, + indptr_j: T.handle, + indptr_k: T.handle, + indptr_l: T.handle, + n: T.int32, + nnz_j: T.int32, + nnz_k: T.int32, + nnz_l: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(n) + J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32") + K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32") + L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32") + X = T.match_sparse_buffer(x, (I, J, K, L), "float32") + Y = T.match_sparse_buffer(y, (I, J), "float32") + with T.iter([T.fuse(I, J), K, L], "SSRR", "reduction_4d_2d") as [vi, vj, vk, vl]: + with T.init(): + Y[vi, vj] = 0.0 + Y[vi, vj] = Y[vi, vj] + X[vi, vj, vk, vl] + + +@T.prim_func +def fused_reduction_4d_3d( + x: T.handle, + y: T.handle, + indptr_j: T.handle, + indptr_k: T.handle, + indptr_l: T.handle, + n: T.int32, + nnz_j: T.int32, + nnz_k: T.int32, + nnz_l: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(n) + J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32") + K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32") + L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32") + X = T.match_sparse_buffer(x, (I, J, K, L), "float32") + Y = T.match_sparse_buffer(y, (I, J, K), "float32") + with T.iter([T.fuse(I, J, K), L], "SSSR", "reduction_4d_3d") as [vi, vj, vk, vl]: + with T.init(): + Y[vi, vj, vk] = 0.0 + Y[vi, vj, vk] = Y[vi, vj, vk] + X[vi, vj, vk, vl] + + @T.prim_func def rgcn_forward( etype: T.handle, diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py index e33735cd2dda..3a7906451e68 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -63,21 +63,18 @@ def test_csr_element_wise(): tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True) -@pytest.mark.skip(reason="Under implementation") def test_bmm(): mod = tvm.IRModule.from_expr(bmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_bmm) -@pytest.mark.skip(reason="Under implementation") def test_sddmm(): mod = tvm.IRModule.from_expr(sddmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_sddmm) -@pytest.mark.skip(reason="Under implementation") def test_fused_sddmm(): mod = tvm.IRModule.from_expr(fused_sddmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) @@ -96,6 +93,16 @@ def test_square_sum_two_K(): tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True) +def test_fused_reduction(): + mod = tvm.IRModule.from_expr(fused_reduction_4d_2d) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_2d, True) + + mod = tvm.IRModule.from_expr(fused_reduction_4d_3d) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_3d, True) + + if __name__ == "__main__": test_csrmm() test_csrmm_dense_iter() @@ -109,3 +116,4 @@ def test_square_sum_two_K(): test_bmm() test_square_sum() test_square_sum_two_K() + test_fused_reduction() diff --git a/tests/python/sparsetir/test_tir_sparse_tensorize.py b/tests/python/sparsetir/test_tir_sparse_tensorize.py index e69de29bb2d1..f87f5c14cbbd 100644 --- a/tests/python/sparsetir/test_tir_sparse_tensorize.py +++ b/tests/python/sparsetir/test_tir_sparse_tensorize.py @@ -0,0 +1 @@ +# TODO \ No newline at end of file