From f3721612f62e61529b445ccceae98e5eaacd1a71 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 26 Nov 2021 14:30:04 -0800 Subject: [PATCH] Axis Dependency Tree aware code-gen and bmm example (#28) * upd * upd * upd * upd * upd * upd * upd * upd * remove redundancy * fix * upd * upd --- include/tvm/tir/sparse.h | 22 +- include/tvm/tir/transform.h | 3 +- python/tvm/tir/transform/transform.py | 10 +- src/tir/ir/sparse.cc | 9 +- src/tir/transforms/lower_sparse_tir.cc | 149 ++++++++--- .../sparsetir/test_tir_sparse_correctness.py | 117 +++++++- .../python/sparsetir/test_tir_sparse_lower.py | 249 ++++++++---------- 7 files changed, 373 insertions(+), 186 deletions(-) diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index ec8c1a125b3a..b725af1a86fd 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -47,6 +47,8 @@ class AxisNode : public Object { String GetName() const { return name; } PrimExpr GetLength() const { return length; } DataType GetIndexType() const { return length->dtype; } + + virtual bool is_fixed() const = 0; static constexpr const char* _type_key = "tir.sparse.Axis"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -141,6 +143,10 @@ class DenseFixedAxisNode : public DenseAxisNode { hash_reduce(from_sparse); } + bool is_fixed() const { + return true; + } + static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode); }; @@ -177,6 +183,10 @@ class DenseVariableAxisNode : public DenseAxisNode { hash_reduce(indptr); } + bool is_fixed() const { + return false; + } + static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode); }; @@ -220,6 +230,10 @@ class SparseFixedAxisNode : public SparseAxisNode { hash_reduce(nnz_cols); } + bool is_fixed() const { + return true; + } + static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode); }; @@ -262,6 +276,10 @@ class SparseVariableAxisNode : public SparseAxisNode { hash_reduce(indices); } + bool is_fixed() const { + return false; + } + static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode); }; @@ -283,9 +301,9 @@ class SparseVariableAxis : public SparseAxis { class AxisTreeNode : public Object { public: // unordered map that stores the parent relationship between axes. - Map> parent; + Map parent; // unordered map that stores the children relationship between axes. - Map, Array> children; + Map> children; void VisitAttrs(AttrVisitor* v) { v->Visit("parent", &parent); diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 99a1f00c6922..286eba616deb 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -494,9 +494,10 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); /*! * \brief Lower SparseTIR to TIR. + * \param axis_tree The axis dependency tree. * \return The pass. */ -TVM_DLL Pass LowerSparseTIR(); +TVM_DLL Pass LowerSparseTIR(AxisTree axis_tree); } // namespace transform } // namespace tir diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 0073503e4fc3..96ae275dc156 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -19,6 +19,7 @@ from typing import Optional from . import _ffi_api from . 
import function_pass as _fpass +from ..sparse import AxisTree def Apply(ftransform): @@ -751,12 +752,17 @@ def ConvertForLoopsToSerial(): return _ffi_api.ConvertForLoopsToSerial() # type: ignore -def LowerSparseTIR(): +def LowerSparseTIR(axis_tree: AxisTree): """Lower SparseTIR to TIR + Parameters + ---------- + axis_tree : AxisTree + The axis dependency tree. + Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerSparseTIR() # type: ignore + return _ffi_api.LowerSparseTIR(axis_tree) # type: ignore diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 102eb84769a0..2ebd3fdb282f 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -146,12 +146,15 @@ AxisTree::AxisTree(Array axis_names, Array> axis_parent "axis_parent_names " "array."; ObjectPtr node = make_object(); - Map> parent; - Map, Array> children; + Map parent; + Map> children; for (size_t i = 0; i < axis_names.size(); i++) { // update parent map & children map String axis_name = axis_names[i]; - Optional parent_name = axis_parent_names[i]; + String parent_name("root"); + if (axis_parent_names[i].defined()) { + parent_name = axis_parent_names[i].value(); + } parent.Set(axis_name, parent_name); auto it = children.find(parent_name); diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 47e53707a852..3e02c9a59daa 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -26,6 +26,8 @@ #include #include +#include +#include #include #include "../schedule/analysis.h" @@ -87,8 +89,8 @@ Map UpdateBufferMap(PrimFunc f) { */ class IndexTransformer : public StmtExprMutator { public: - explicit IndexTransformer(AccessAndDependencyCollector collector) - : collector_(std::move(collector)) {} + explicit IndexTransformer(AccessAndDependencyCollector collector, AxisTree axis_tree) + : collector_(std::move(collector)), axis_tree_(std::move(axis_tree)) {} private: /*! @@ -281,43 +283,124 @@ class IndexTransformer : public StmtExprMutator { sp_block->init.defined() ? VisitStmt(sp_block->init.value()) : Optional(NullOpt); Stmt body = VisitStmt(sp_block->body); - // Step 2. Create the new outer loop vars. - Array loop_vars; + // Step 2. Create the new loop vars. std::unordered_map var_map; - loop_vars.reserve(n_iter); + Array all_loop_vars; var_map.reserve(n_iter); for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) { Var loop_var("v_" + sp_iter->var->name_hint); - loop_vars.push_back(loop_var); + all_loop_vars.push_back(loop_var); var_map[sp_iter->var.get()] = loop_var; } - // Step 3. Create block iters and iter bindings. - Array block_iters; - Array iter_bindings; - block_iters.reserve(n_iter); - iter_bindings.reserve(n_iter); - for (int i = 0; i < n_iter; ++i) { - block_iters.push_back(SpIterVarToIterVar(sp_block->sp_iter_vars[i], var_map)); - iter_bindings.push_back(loop_vars[i]); - } + // Step 3. Collet block iters and iter bindings. + std::set in_stack; + in_stack.insert("root"); + /* A stack that stores block itervars in each block. */ + std::stack> block_iters_st; + /* A stack that stores itervar bindings in each block. */ + std::stack> iter_bindings_st; + /* A stack that stores generated loop vars in each block. */ + std::stack> loop_vars_st; + /* A stack that stores whether to place init block in each block. */ + std::stack place_init_st; + /* An indicator that records whether init block has been set. */ + bool init_set = false; + do { + /* Block itervars of current block. 
*/ + Array block_iters; + /* Itervar bindings of current block. */ + Array iter_bindings; + /* Axis names of current block. */ + Array axis_names; + /* Generated loop vars of current block. */ + Array loop_vars; + /* An indicator that records whether there is reduction axis in current block. */ + bool has_reduction_var = false; + for (int i = 0; i < n_iter; ++i) { + SpIterVar sp_it_var = sp_block->sp_iter_vars[i]; + String axis_name = sp_it_var->axis->name; + auto&& parent_axis = axis_tree_->parent.Get(axis_name); + CHECK(parent_axis.defined()) << "Sparse IterVar not defined in Axis Tree."; + String parent_axis_name = parent_axis.value(); + bool is_fixed_axis = sp_it_var->axis->is_fixed(); + /* Add itervar to current block when + * - it's not used yet (not in stack) and + * - it's parent axis was used in outer blocks or + * - it's an iterator to a fixed axis. + */ + if ((is_fixed_axis || in_stack.find(parent_axis_name) != in_stack.end()) && + in_stack.find(axis_name) == in_stack.end()) { + loop_vars.push_back(all_loop_vars[i]); + axis_names.push_back(std::move(axis_name)); + block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map)); + iter_bindings.push_back(all_loop_vars[i]); + has_reduction_var |= sp_it_var->is_reduction; + } + } + + /* Tag axes in current block as "in-stack". */ + for (const String&& axis_name : axis_names) { + in_stack.insert(std::move(axis_name)); + } + + /* Update stack. */ + if (!block_iters.empty()) { + block_iters_st.push(std::move(block_iters)); + iter_bindings_st.push(std::move(iter_bindings)); + loop_vars_st.push(std::move(loop_vars)); + if (init_set) { + place_init_st.push(false); + } else { + place_init_st.push(has_reduction_var); + init_set |= has_reduction_var; + } + } else { + break; + } + } while (true); // Step 4. Generate the read-region and write-retion of the block. Array reads{nullptr}; Array writes{nullptr}; GenerateReadWriteRegions(sp_block, &reads, &writes); - // Step 5. Create the block and block-realize - Map mapping; - mapping.Set("sparse", Bool(true)); - Block block(block_iters, std::move(reads), std::move(writes), sp_block->name, std::move(body), - std::move(init), {}, {}, std::move(mapping)); - BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block)); - - // Step 6. Create outer loops and the block binding. - Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars); + // Step 5. Generate nested blocks and loops from innermost to outermost. + int blk_counter = 0; + while (!block_iters_st.empty()) { + Array block_iters = std::move(block_iters_st.top()); + Array iter_bindings = std::move(iter_bindings_st.top()); + Array loop_vars = std::move(loop_vars_st.top()); + bool place_init = place_init_st.top(); + block_iters_st.pop(); + iter_bindings_st.pop(); + loop_vars_st.pop(); + place_init_st.pop(); + + Map mapping; + mapping.Set("sparse", Bool(true)); + String blk_name_hint = sp_block->name; + if (blk_counter != 0) { + blk_name_hint = blk_name_hint + "_" + std::to_string(blk_counter); + } + Block block(/*iter_vars=*/block_iters, + /*reads=*/reads, + /*writes=*/writes, + /*name_hint=*/blk_name_hint, + /*body=*/std::move(body), + /*init=*/place_init ? std::move(init) : NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/std::move(mapping), + /*span=*/sp_block->span); + BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block)); + // Generate outer loop and the block binding. 
+ Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars); + body = loop; + blk_counter += 1; + } - return loop; + return body; } /*! @@ -380,9 +463,10 @@ class IndexTransformer : public StmtExprMutator { } /*! - * \brief generated nested for loops for sparse block. + * \brief generated nested for-loops for sparse block. * \param block_iters The iterators defined in sparse blocks. * \param loop_vars The loop variables binded with block iterators. + * \return The outermost loop. */ Stmt GenerateLoops(Stmt body, const Array& block_iters, const Array& loop_vars) { int n_iter = static_cast(block_iters.size()); @@ -394,6 +478,7 @@ class IndexTransformer : public StmtExprMutator { } AccessAndDependencyCollector collector_; + AxisTree axis_tree_; arith::Analyzer ana_; std::unordered_set buffer_read_; std::unordered_set buffer_write_; @@ -411,11 +496,12 @@ Stmt WrapWithRootBlock(Stmt body) { } /*! - * \brief Rewrite the given primitive function + * \brief Rewrite the given primitive function. + * \param axis_tree The axis dependency tree. * \param f The Sparse-TIR primitive function to lower. * \return lowered primitive function in TIR. */ -PrimFunc LowerSparseTIR(PrimFunc f) { +PrimFunc LowerSparseTIR(AxisTree axis_tree, PrimFunc f) { // Only apply this pass to TIR that is not from TE schedules if (!IsFromLegacyTESchedule(f)) { PrimFuncNode* fptr = f.CopyOnWrite(); @@ -425,7 +511,7 @@ PrimFunc LowerSparseTIR(PrimFunc f) { AccessAndDependencyCollector collector; collector.Collect(f->body); // Step 3. Lower indices. - fptr->body = IndexTransformer(collector)(std::move(f->body)); + fptr->body = IndexTransformer(collector, axis_tree)(std::move(f->body)); // Step 4. Wrap the function body with a root block. fptr->body = WrapWithRootBlock(std::move(fptr->body)); return f; @@ -438,10 +524,11 @@ namespace transform { /*! * \brief The lowering pass from TIR to Sparse TIR. + * \param axis_tree The axis dependency tree. */ -Pass LowerSparseTIR() { +Pass LowerSparseTIR(AxisTree axis_tree) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return LowerSparseTIR(std::move(f)); + return LowerSparseTIR(std::move(axis_tree), std::move(f)); }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseTIR", {}); } diff --git a/tests/python/sparsetir/test_tir_sparse_correctness.py b/tests/python/sparsetir/test_tir_sparse_correctness.py index 6e2f2683ea8b..63bcce46a4b2 100644 --- a/tests/python/sparsetir/test_tir_sparse_correctness.py +++ b/tests/python/sparsetir/test_tir_sparse_correctness.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from ctypes import c_float import tvm +import tvm.testing from tvm.runtime.ndarray import device import tvm.tir as tir import scipy.sparse as sp @@ -113,7 +115,40 @@ def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: with T.init(): C_data[vij] = 0. 
C_data[vij] = C_data[vij] + \ - A[T.lower_bound(C_indptr.data, vij, 0, M + 1) * K + vk] * B[C_indices[vij] * K + vk] + A[(T.upper_bound(C_indptr.data, vij, 0, M + 1) - 1) * K + vk] * B[C_indices[vij] * K + vk] + + +@T.prim_func +def bmm_tir(a: T.handle, b: T.handle, c: T.handle, + indptr_i: T.handle, indptr_j: T.handle, indptr_k: T.handle, + indptr_ij: T.handle, indptr_jk: T.handle, indptr_ik: T.handle, + BATCH: T.int32, + NNZ_IJ: T.int32, NNZ_JK: T.int32, NNZ_IK: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (NNZ_IJ,), "float32") + B = T.match_buffer(b, (NNZ_JK,), "float32") + C = T.match_buffer(c, (NNZ_IK,), "float32") + indptr_I = T.match_buffer(indptr_i, (BATCH + 1,), "int32") + indptr_J = T.match_buffer(indptr_j, (BATCH + 1,), "int32") + indptr_K = T.match_buffer(indptr_k, (BATCH + 1,), "int32") + indptr_IJ = T.match_buffer(indptr_ij, (BATCH + 1,), "int32") + indptr_JK = T.match_buffer(indptr_jk, (BATCH + 1,), "int32") + indptr_IK = T.match_buffer(indptr_ik, (BATCH + 1,), "int32") + for b in T.grid(BATCH): + with T.block("bmm_outer"): + T.block_attr({"sparse": True}) + vb = T.axis.S(BATCH, b) + with T.init(): + T.evaluate(1) + for i, j, k in T.grid(indptr_I[vb + 1] - indptr_I[vb], indptr_J[vb + 1] - indptr_J[vb], indptr_K[vb + 1] - indptr_K[vb]): + with T.block("bmm_inner"): + T.block_attr({"sparse": True}) + vi, vj, vk = T.axis.remap("SRS", [i, j, k]) + with T.init(): + C[indptr_IK[vb] + vi * (indptr_K[vb + 1] - indptr_K[vb]) + vk] = 0. + C[indptr_IK[vb] + vi * (indptr_K[vb + 1] - indptr_K[vb]) + vk] = C[indptr_IK[vb] + vi * (indptr_K[vb + 1] - indptr_K[vb]) + vk] +\ + A[indptr_IJ[vb] + vi * (indptr_J[vb + 1] - indptr_J[vb]) + vj] * \ + B[indptr_JK[vb] + vj * (indptr_K[vb + 1] - indptr_K[vb]) + vk] def test_csrmm(): @@ -151,7 +186,7 @@ def test_csrmm(): f(A_data, X_nd, Y_nd, A_indptr, A_indices) # assertion - assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy()) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5) def test_bsrmm(): @@ -199,7 +234,7 @@ def test_bsrmm(): f(A_data, X_nd, Y_nd, A_indptr, A_indices) # assertion - assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy()) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5) def test_ellmm(): @@ -240,7 +275,7 @@ def test_ellmm(): f(A_data, X_nd, Y_nd, A_indices) # assertion - assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy()) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5) def test_sddmm(): @@ -269,7 +304,6 @@ def test_sddmm(): ij, k = sch.get_loops(blk) sch.bind(ij, "blockIdx.x") sch.bind(k, "threadIdx.x") - sch.decompose_reduction(blk, k) # convert numpy tensor to tvm ndarray C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0)) @@ -280,16 +314,81 @@ def test_sddmm(): # build function f = tvm.build(sch.mod['main'], target="cuda") - # print(f.imported_modules[0].get_source()) f(X_nd, Y_nd, C_data, C_indptr, C_indices) # assertion - np.allclose(z_ground_truth, C_data.numpy()) + tvm.testing.assert_allclose(z_ground_truth, C_data.numpy(), rtol=1e-5) def test_bmm(): - # TODO(zihao) - pass + # generate random input + batch_size = 32 + n_arr = np.random.randint(128, 1024, size=(batch_size,)).astype("int32") + m_arr = np.random.randint(128, 1024, size=(batch_size,)).astype("int32") + k_arr = np.random.randint(128, 1024, size=(batch_size,)).astype("int32") + nm_arr = n_arr * m_arr + mk_arr = m_arr * k_arr + nk_arr = n_arr * k_arr + indptr_n = 
np.concatenate(([0], n_arr)).cumsum() + indptr_m = np.concatenate(([0], m_arr)).cumsum() + indptr_k = np.concatenate(([0], k_arr)).cumsum() + indptr_nm = np.concatenate(([0], nm_arr)).cumsum() + indptr_mk = np.concatenate(([0], mk_arr)).cumsum() + indptr_nk = np.concatenate(([0], nk_arr)).cumsum() + nnz_ij = indptr_nm[-1] + nnz_jk = indptr_mk[-1] + nnz_ik = indptr_nk[-1] + As = [ + np.random.rand(n, m).astype("float32") for n, m in zip(n_arr, m_arr) + ] + Bs = [ + np.random.rand(m, k).astype("float32") for m, k in zip(m_arr, k_arr) + ] + Cs = [ + np.matmul(A, B) for A, B in zip(As, Bs) + ] + A_flatten = np.concatenate([A.flatten() for A in As], 0) + B_flatten = np.concatenate([B.flatten() for B in Bs], 0) + c_flatten = np.concatenate([C.flatten() for C in Cs], 0) + + # specialize function + _, _, _, _, _, _, _, _, _, BATCH, NNZ_IJ, NNZ_JK, NNZ_IK = bmm_tir.params + sch = tir.Schedule( + bmm_tir.specialize({ + BATCH: batch_size, NNZ_IJ: nnz_ij, NNZ_JK: nnz_jk, NNZ_IK: nnz_ik + }) + ) + bmm_outer = sch.get_block("bmm_outer") + b, = sch.get_loops(bmm_outer) + bmm_inner = sch.get_block("bmm_inner") + i, j, k = sch.get_loops(bmm_inner) + sch.reorder(i, k, j) + io, ii = sch.split(i, [None, 32]) + ko, ki = sch.split(k, [None, 32]) + sch.bind(b, "blockIdx.x") + sch.bind(ki, "threadIdx.x") + sch.bind(ii, "threadIdx.y") + sch.decompose_reduction(bmm_inner, j) + + # convert numpy tensor to tvm ndarray + dev = tvm.cuda(0) + A_nd = tvm.nd.array(A_flatten, device=dev) + B_nd = tvm.nd.array(B_flatten, device=dev) + C_nd = tvm.nd.array(np.zeros_like(c_flatten), device=dev) + indptr_n_nd = tvm.nd.array(indptr_n.astype("int32"), device=dev) + indptr_m_nd = tvm.nd.array(indptr_m.astype("int32"), device=dev) + indptr_k_nd = tvm.nd.array(indptr_k.astype("int32"), device=dev) + indptr_nm_nd = tvm.nd.array(indptr_nm.astype("int32"), device=dev) + indptr_mk_nd = tvm.nd.array(indptr_mk.astype("int32"), device=dev) + indptr_nk_nd = tvm.nd.array(indptr_nk.astype("int32"), device=dev) + + # build function + f = tvm.build(sch.mod["main"], target="cuda") + print(f.imported_modules[0].get_source()) + f(A_nd, B_nd, C_nd, indptr_n_nd, indptr_m_nd, indptr_k_nd, indptr_nm_nd, indptr_mk_nd, indptr_nk_nd) + + # assertion + tvm.testing.assert_allclose(C_nd.numpy(), c_flatten, rtol=1e-5) if __name__ == "__main__": diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py index 70b2a83f6b82..07a7849bf1f2 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -23,6 +23,7 @@ import scipy.sparse as sp import numpy as np from tvm.script import tir as T +from tvm.tir.sparse import AxisTree @T.prim_func @@ -50,43 +51,30 @@ def csrmm( @T.prim_func -def lowered_csrmm( - a: T.handle, - b: T.handle, - c: T.handle, - indptr: T.handle, - indices: T.handle, - n: T.int32, - m: T.int32, - k: T.int32, - nnz: T.int32, -) -> None: +def lowered_csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, m: T.int32, k: T.int32, nnz: T.int32) -> None: A_data = T.match_buffer(a, [nnz], dtype="float32") B_data = T.match_buffer(b, [m * k], dtype="float32") C_data = T.match_buffer(c, [n * k], dtype="float32") J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") J_indices = T.match_buffer(indices, [nnz], dtype="int32") - for v_vi in T.serial(0, n): - for v_vj, v_vk in T.grid(J_indptr[v_vi + 1] - J_indptr[v_vi], k): - with T.block("csrmm"): - T.block_attr({"sparse": True}) - vi, vj, vk = 
T.axis.remap("SRS", [v_vi, v_vj, v_vk]) - T.reads( - [ - J_indptr[0 : n + 1], - J_indices[0:nnz], - A_data[0:nnz], - B_data[0 : m * k], - C_data[0 : n * k], - ] - ) - T.writes([C_data[0 : n * k]]) - with T.init(): - C_data[vi * k + vk] = T.float32(0) - C_data[vi * k + vk] = ( - C_data[vi * k + vk] - + A_data[J_indptr[vi] + vj] * B_data[J_indices[J_indptr[vi] + vj] * k + vk] - ) + for v_vi, v_vk in T.grid(n, k): + with T.block("csrmm_outer"): + vi, vk = T.axis.remap("SS", [v_vi, v_vk]) + T.reads([J_indptr[0: n + 1], J_indices[0: nnz], + A_data[0: nnz], B_data[0: m * k], C_data[0: n * k]]) + T.writes([C_data[0: n * k]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): + with T.block("csrmm"): + vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) + T.reads([J_indptr[0: n + 1], J_indices[0: nnz], + A_data[0: nnz], B_data[0: m * k], C_data[0: n * k]]) + T.writes([C_data[0: n * k]]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[vi * k + vk] = T.float32(0) + C_data[vi * k + vk] = C_data[vi * k + vk] + A_data[J_indptr[vi] + vj] * \ + B_data[J_indices[J_indptr[vi] + vj] * k + vk] @T.prim_func @@ -110,29 +98,26 @@ def csr_reduce( @T.prim_func -def lowered_csr_reduce( - a: T.handle, - b: T.handle, - indptr: T.handle, - indices: T.handle, - n: T.int32, - m: T.int32, - nnz: T.int32, -) -> None: +def lowered_csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, m: T.int32, nnz: T.int32) -> None: A_data = T.match_buffer(a, [nnz], dtype="float32") B_data = T.match_buffer(b, [n], dtype="float32") J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") J_indices = T.match_buffer(indices, [nnz], dtype="int32") for v_vi in T.serial(0, n): - for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): - with T.block("csr_reduce"): - T.block_attr({"sparse": True}) - vi, vj = T.axis.remap("SR", [v_vi, v_vj]) - T.reads([J_indptr[0 : n + 1], J_indices[0:nnz], A_data[0:nnz], B_data[0:n]]) - T.writes([B_data[0:n]]) - with T.init(): - B_data[vi] = T.float32(0) - B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] + with T.block("csr_reduce_outer"): + vi = T.axis.spatial(n, v_vi) + T.reads([J_indptr[0: n + 1], J_indices[0: nnz], A_data[0: nnz], B_data[0: n]]) + T.writes([B_data[0: n]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): + with T.block("csr_reduce"): + vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) + T.reads([J_indptr[0: n + 1], J_indices[0: nnz], A_data[0: nnz], B_data[0: n]]) + T.writes([B_data[0: n]]) + T.block_attr({"sparse": True}) + with T.init(): + B_data[vi] = T.float32(0) + B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] @T.prim_func @@ -170,47 +155,30 @@ def bsrmm( @T.prim_func -def lowered_bsrmm( - a: T.handle, - b: T.handle, - c: T.handle, - indptr: T.handle, - indices: T.handle, - nb: T.int32, - mb: T.int32, - nnzb: T.int32, - blk: T.int32, - feat_size: T.int32, -) -> None: +def lowered_bsrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, nb: T.int32, mb: T.int32, nnzb: T.int32, blk: T.int32, feat_size: T.int32) -> None: A_data = T.match_buffer(a, [nnzb * blk * blk], dtype="float32") B_data = T.match_buffer(b, [mb * blk * feat_size], dtype="float32") C_data = T.match_buffer(c, [nb * blk * feat_size], dtype="float32") J_indptr = T.match_buffer(indptr, [nb + 1], dtype="int32") J_indices = T.match_buffer(indices, [nnzb], dtype="int32") - for v_vi in T.serial(0, nb): - for v_vj, v_vbi, v_vbj, v_vf 
in T.grid( - J_indptr[v_vi + 1] - J_indptr[v_vi], blk, blk, feat_size - ): - with T.block("bsrmm"): - T.block_attr({"sparse": True}) - vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf]) - T.reads( - [ - J_indptr[0 : nb + 1], - J_indices[0:nnzb], - A_data[0 : nnzb * blk * blk], - B_data[0 : mb * blk * feat_size], - C_data[0 : nb * blk * feat_size], - ] - ) - T.writes([C_data[0 : nb * blk * feat_size]]) - with T.init(): - C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) - C_data[(vi * blk + vbi) * feat_size + vf] = ( - C_data[(vi * blk + vbi) * feat_size + vf] - + A_data[((J_indptr[vi] + vj) * blk + vbi) * blk + vbj] - * B_data[(J_indices[J_indptr[vi] + vj] * blk + vbj) * feat_size + vf] - ) + for v_vi, v_vbi, v_vbj, v_vf in T.grid(nb, blk, blk, feat_size): + with T.block("bsrmm_outer"): + vi, vbi, vbj, vf = T.axis.remap("SSRS", [v_vi, v_vbi, v_vbj, v_vf]) + T.reads([J_indptr[0: nb + 1], J_indices[0: nnzb], A_data[0: nnzb * blk * blk], + B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]]) + T.writes([C_data[0: nb * blk * feat_size]]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) + for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): + with T.block("bsrmm"): + vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) + T.reads([J_indptr[0: nb + 1], J_indices[0: nnzb], A_data[0: nnzb * blk * blk], + B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]]) + T.writes([C_data[0: nb * blk * feat_size]]) + T.block_attr({"sparse": True}) + C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[( + (J_indptr[vi] + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[J_indptr[vi] + vj] * blk + vbj) * feat_size + vf] @T.prim_func @@ -235,7 +203,7 @@ def ellpack_mm( B = T.match_sparse_buffer(b, (T.to_dense(J), BJ, F), mb * blk * feat_size, "float32") C = T.match_sparse_buffer(c, (I, BI, F), nb * blk * feat_size, "float32") - with T.iter([T.cord(I), T.pos(J), T.cord(BI), T.cord(BJ), T.cord(F)], "SRSRS", "bsrmm") as [ + with T.iter([T.cord(I), T.pos(J), T.cord(BI), T.cord(BJ), T.cord(F)], "SRSRS", "ellmm") as [ vi, vj, vbi, @@ -248,42 +216,22 @@ def ellpack_mm( @T.prim_func -def lowered_ellpack_mm( - a: T.handle, - b: T.handle, - c: T.handle, - indices: T.handle, - nb: T.int32, - mb: T.int32, - feat_size: T.int32, - nnz: T.int32, - col: T.int32, - blk: T.int32, -) -> None: +def lowered_ellpack_mm(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, nb: T.int32, mb: T.int32, feat_size: T.int32, nnz: T.int32, col: T.int32, blk: T.int32) -> None: A_data = T.match_buffer(a, [nnz * blk * blk], dtype="float32") B_data = T.match_buffer(b, [mb * blk * feat_size], dtype="float32") C_data = T.match_buffer(c, [nb * blk * feat_size], dtype="float32") J_indices = T.match_buffer(indices, [nnz], dtype="int32") for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size): - with T.block("bsrmm"): - T.block_attr({"sparse": True}) + with T.block("ellmm"): vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf]) - T.reads( - [ - J_indices[0:nnz], - A_data[0 : nnz * blk * blk], - B_data[0 : mb * blk * feat_size], - C_data[0 : nb * blk * feat_size], - ] - ) - T.writes([C_data[0 : nb * blk * feat_size]]) + T.reads([J_indices[0: nnz], A_data[0: nnz * blk * blk], + B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]]) + T.writes([C_data[0: nb * blk * feat_size]]) + T.block_attr({"sparse": True}) with 
T.init(): C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) - C_data[(vi * blk + vbi) * feat_size + vf] = ( - C_data[(vi * blk + vbi) * feat_size + vf] - + A_data[((vi * col + vj) * blk + vbi) * blk + vbj] - * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] - ) + C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[((vi * + col + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] @T.prim_func @@ -347,32 +295,34 @@ def csr_element_wise( @T.prim_func -def lowered_csr_element_wise( - a: T.handle, - b: T.handle, - indptr: T.handle, - indices: T.handle, - m: T.int32, - n: T.int32, - nnz: T.int32, -) -> None: +def lowered_csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, nnz: T.int32) -> None: A_data = T.match_buffer(a, [nnz], dtype="float32") B_data = T.match_buffer(b, [nnz], dtype="float32") J_indptr = T.match_buffer(indptr, [m + 1], dtype="int32") J_indices = T.match_buffer(indices, [nnz], dtype="int32") for v_vi in T.serial(0, m): - for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): - with T.block("csr_element_wise"): - T.block_attr({"sparse": True}) - vi, vj = T.axis.remap("SS", [v_vi, v_vj]) - T.reads([J_indptr[0 : m + 1], J_indices[0:nnz], A_data[0:nnz]]) - T.writes([B_data[0:nnz]]) - B_data[J_indptr[vi] + vj] = A_data[J_indptr[vi] + vj] * T.float32(2.5) + with T.block("csr_element_wise_outer"): + vi = T.axis.spatial(m, v_vi) + T.reads([J_indptr[0: m + 1], J_indices[0: nnz], A_data[0: nnz]]) + T.writes([B_data[0: nnz]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): + with T.block("csr_element_wise"): + vj = T.axis.spatial(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) + T.reads([J_indptr[0: m + 1], J_indices[0: nnz], A_data[0: nnz]]) + T.writes([B_data[0: nnz]]) + T.block_attr({"sparse": True}) + B_data[J_indptr[vi] + vj] = A_data[J_indptr[vi] + vj] * T.float32(2.5) def test_csrmm(): mod = tvm.IRModule.from_expr(csrmm) - mod = tvm.tir.transform.LowerSparseTIR()(mod) + t = AxisTree({ + "J": "I", + "I": None, + "K": None + }) + mod = tvm.tir.transform.LowerSparseTIR(t)(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr") @@ -395,7 +345,11 @@ def test_csrmm(): def test_csr_reduce(): mod = tvm.IRModule.from_expr(csr_reduce) - mod = tvm.tir.transform.LowerSparseTIR()(mod) + t = AxisTree({ + "J": "I", + "I": None + }) + mod = tvm.tir.transform.LowerSparseTIR(t)(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_csr_reduce, True) A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") @@ -416,7 +370,14 @@ def test_csr_reduce(): def test_bsrmm(): mod = tvm.IRModule.from_expr(bsrmm) - mod = tvm.tir.transform.LowerSparseTIR()(mod) + t = AxisTree({ + "J": "I", + "I": None, + "BJ": None, + "BI": None, + "F": None + }) + mod = tvm.tir.transform.LowerSparseTIR(t)(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_bsrmm, True) block_size = 16 @@ -456,7 +417,14 @@ def test_bsrmm(): def test_ellpack_mm(): mod = tvm.IRModule.from_expr(ellpack_mm) - mod = tvm.tir.transform.LowerSparseTIR()(mod) + t = AxisTree({ + "J": "I", + "I": None, + "F": None, + "BI": None, + "BJ": None + }) + mod = tvm.tir.transform.LowerSparseTIR(t)(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_ellpack_mm, True) nnz_cols = 4 @@ -505,13 +473,18 @@ def test_ellpack_mm(): def test_batch_mm(): 
    mod = tvm.IRModule.from_expr(batch_mm)
-    mod = tvm.tir.transform.LowerSparseTIR()(mod)
+    t = AxisTree({})
+    mod = tvm.tir.transform.LowerSparseTIR(t)(mod)
     # print(mod["main"].script(tir_prefix="T"))
 
 
 def test_csr_element_wise():
     mod = tvm.IRModule.from_expr(csr_element_wise)
-    mod = tvm.tir.transform.LowerSparseTIR()(mod)
+    t = AxisTree({
+        "J": "I",
+        "I": None
+    })
+    mod = tvm.tir.transform.LowerSparseTIR(t)(mod)
     tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True)
 
     A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr")
@@ -535,5 +508,5 @@ def test_csr_element_wise():
     test_csr_reduce()
     test_bsrmm()
     test_ellpack_mm()
-    test_batch_mm()
+    # test_batch_mm()
     test_csr_element_wise()
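
Usage note: a minimal sketch of the new pass interface introduced by this patch, distilled from the tests above. After this change, LowerSparseTIR takes an AxisTree whose keys are axis names and whose values are the parent axis names (None marks a root-level axis, which the C++ side maps to the implicit "root"). The `csrmm` PrimFunc and the axis names "I"/"J"/"K" below are the ones defined in tests/python/sparsetir/test_tir_sparse_lower.py; treat this as an illustrative sketch rather than additional patch content.

    import tvm
    from tvm.tir.sparse import AxisTree

    # "J" is a variable axis whose extent depends on "I"; "I" and "K" are root axes.
    t = AxisTree({"J": "I", "I": None, "K": None})

    # Lower the SparseTIR csrmm kernel to regular TIR using the axis dependency tree.
    mod = tvm.IRModule.from_expr(csrmm)
    mod = tvm.tir.transform.LowerSparseTIR(t)(mod)
    print(mod["main"].script(tir_prefix="T"))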