From bce0f8eb8dcaecb0270ecab6cef171a4c46dff9e Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Mon, 22 Nov 2021 18:35:40 -0800
Subject: [PATCH] Add "sparse" block attribute. (#26)

---
 src/tir/schedule/analysis/analysis.cc                 | 13 +++++++++----
 src/tir/transforms/lower_sparse_tir.cc                |  4 +++-
 .../python/sparsetir/test_tir_sparse_correctness.py   | 12 +++++++++---
 tests/python/sparsetir/test_tir_sparse_lower.py       |  5 +++++
 4 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index c8f28968fd46..8c05389455b2 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -163,12 +163,17 @@ Definition of a scope that is a stage pipeline:
       !IsReductionBlock(self, block_sref, scope_root_sref)) {
     const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
     // NOTE(Zihao): check if the block has atomic attribute.
-    auto&& it = block->annotations.find("atomic");
+    auto&& it_atomic = block->annotations.find("atomic");
     bool is_atomic = false;
-    if (it != block->annotations.end()) {
-      is_atomic = ((*it).second).as<IntImmNode>()->value;
+    if (it_atomic != block->annotations.end()) {
+      is_atomic = ((*it_atomic).second).as<IntImmNode>()->value;
     }
-    if (!is_atomic) {
+    auto&& it_sparse = block->annotations.find("sparse");
+    bool is_sparse = false;
+    if (it_sparse != block->annotations.end()) {
+      is_sparse = ((*it_sparse).second).as<IntImmNode>()->value;
+    }
+    if (!is_sparse && !is_atomic) {
       throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
                                     GetRef<Block>(block));
     }
diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc
index cca219ed492b..47e53707a852 100644
--- a/src/tir/transforms/lower_sparse_tir.cc
+++ b/src/tir/transforms/lower_sparse_tir.cc
@@ -308,8 +308,10 @@ class IndexTransformer : public StmtExprMutator {
     GenerateReadWriteRegions(sp_block, &reads, &writes);
 
     // Step 5. Create the block and block-realize
+    Map<String, ObjectRef> mapping;
+    mapping.Set("sparse", Bool(true));
     Block block(block_iters, std::move(reads), std::move(writes), sp_block->name, std::move(body),
-                std::move(init));
+                std::move(init), {}, {}, std::move(mapping));
     BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));
 
     // Step 6. Create outer loops and the block binding.
diff --git a/tests/python/sparsetir/test_tir_sparse_correctness.py b/tests/python/sparsetir/test_tir_sparse_correctness.py
index b412c62ce1b2..f56fa1566246 100644
--- a/tests/python/sparsetir/test_tir_sparse_correctness.py
+++ b/tests/python/sparsetir/test_tir_sparse_correctness.py
@@ -31,6 +31,7 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha
     B = T.match_sparse_buffer(b, (T.to_dense(J), K), n * k, "float32")
     C = T.match_sparse_buffer(c, (I, K), m * k, "float32")
     with T.iter([T.cord(I), T.cord(J), T.cord(K)], "SRS", "csrmm") as [vi, vj, vk]:
+        T.block_attr({"sparse": True})
         with T.init():
             C[vi, vk] = 0.0
         C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]
@@ -51,6 +52,7 @@ def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
                 C[vi * K + vk] = 0.
             for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]):
                 with T.block("spmm_inner"):
+                    T.block_attr({"sparse": True})
                     vj = T.axis.R(NNZ, j + A_indptr[vi])
                     C[vi * K + vk] = C[vi * K + vk] + \
                         A_data[vj] * B[A_indices[vj] * K + vk]
@@ -71,6 +73,7 @@ def bsrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
                 C[(vio * BLOCK_SIZE + vii) * K + vk] = 0.
             for jo in T.serial(0, A_indptr[vio + 1] - A_indptr[vio]):
                 with T.block("spmm_inner"):
+                    T.block_attr({"sparse": True})
                     vjo = T.axis.R(NNZB, jo + A_indptr[vio])
                     C[(vio * BLOCK_SIZE + vii) * K + vk] = C[(vio * BLOCK_SIZE + vii) * K + vk] + A_data[(
                         vjo * BLOCK_SIZE + vii) * BLOCK_SIZE + vji] * B[(A_indices[vjo] * BLOCK_SIZE + vji) * K + vk]
@@ -85,6 +88,7 @@ def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int
     A_indices = T.match_buffer(indices, (M * NNZ_COLS,), "int32")
     for i, j, k in T.grid(M, NNZ_COLS, K):
         with T.block("spmm"):
+            T.block_attr({"sparse": True})
             vi, vj, vk = T.axis.remap("SRS", [i, j, k])
             with T.init():
                 C[vi * K + vk] = 0.
@@ -102,6 +106,7 @@ def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
     C_indices = T.match_buffer(indices, (NNZ,), "int32")
     for ij, k in T.grid(NNZ, K):
         with T.block("sddmm"):
+            T.block_attr({"sparse": True})
             vij, vk = T.axis.remap("SR", [ij, k])
             T.reads([A[0: M * K], B[0: N * K], C_data[vij], C_indices[vij], C_indptr[0: M + 1]])
             T.writes([C_data[vij]])
@@ -262,10 +267,10 @@ def test_sddmm():
     )
     blk = sch.get_block("sddmm")
     ij, k = sch.get_loops(blk)
-    #sch.decompose_reduction(blk, ij)
+    # TODO(zihao): fix the behavior in the future.
+    # sch.decompose_reduction(blk, ij)
     sch.bind(ij, "blockIdx.x")
-    ko, ki = sch.split(k, [None, 1])
-    sch.bind(ki, "threadIdx.x")
+    sch.bind(k, "threadIdx.x")
 
     # convert numpy tensor to tvm ndarray
     C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
@@ -276,6 +281,7 @@ def test_sddmm():
 
     # build function
     f = tvm.build(sch.mod['main'], target="cuda")
+    # print(f.imported_modules[0].get_source())
     f(X_nd, Y_nd, C_data, C_indptr, C_indices)
 
     # assertion
diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py
index 684886b4cd76..70b2a83f6b82 100644
--- a/tests/python/sparsetir/test_tir_sparse_lower.py
+++ b/tests/python/sparsetir/test_tir_sparse_lower.py
@@ -69,6 +69,7 @@ def lowered_csrmm(
     for v_vi in T.serial(0, n):
         for v_vj, v_vk in T.grid(J_indptr[v_vi + 1] - J_indptr[v_vi], k):
             with T.block("csrmm"):
+                T.block_attr({"sparse": True})
                 vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk])
                 T.reads(
                     [
@@ -125,6 +126,7 @@ def lowered_csr_reduce(
     for v_vi in T.serial(0, n):
         for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
             with T.block("csr_reduce"):
+                T.block_attr({"sparse": True})
                 vi, vj = T.axis.remap("SR", [v_vi, v_vj])
                 T.reads([J_indptr[0 : n + 1], J_indices[0:nnz], A_data[0:nnz], B_data[0:n]])
                 T.writes([B_data[0:n]])
@@ -190,6 +192,7 @@ def lowered_bsrmm(
                 J_indptr[v_vi + 1] - J_indptr[v_vi], blk, blk, feat_size
             ):
                 with T.block("bsrmm"):
+                    T.block_attr({"sparse": True})
                     vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf])
                     T.reads(
                         [
@@ -263,6 +266,7 @@ def lowered_ellpack_mm(
     J_indices = T.match_buffer(indices, [nnz], dtype="int32")
     for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size):
         with T.block("bsrmm"):
+            T.block_attr({"sparse": True})
             vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf])
             T.reads(
                 [
@@ -359,6 +363,7 @@ def lowered_csr_element_wise(
     for v_vi in T.serial(0, m):
         for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
             with T.block("csr_element_wise"):
+                T.block_attr({"sparse": True})
                 vi, vj = T.axis.remap("SS", [v_vi, v_vj])
                 T.reads([J_indptr[0 : m + 1], J_indices[0:nnz], A_data[0:nnz]])
                 T.writes([B_data[0:nnz]])
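For reference, the sketch below (not part of the patch) shows how the "sparse" annotation added above can be read back from the Python side. It is a minimal sketch under assumptions: the helper name is_sparse_block is hypothetical, and it relies only on the stock tir.Schedule API plus the annotation this patch attaches during lowering.

    import tvm
    from tvm import tir


    def is_sparse_block(sch: tir.Schedule, block_name: str) -> bool:
        # Hypothetical helper: look up a block by name and test the "sparse"
        # annotation that LowerSparseTIR now attaches (see the change above).
        block = sch.get(sch.get_block(block_name))
        ann = block.annotations
        return "sparse" in ann and ann["sparse"].value != 0


    # Usage, assuming `mod` is an IRModule produced by the lowering tests above:
    #   sch = tir.Schedule(mod)
    #   assert is_sparse_block(sch, "csrmm")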