diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index e1cd6d3f3bc3..59938d0af7a3 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -955,7 +955,7 @@ class Attach(SpecialStmt): def __init__(self): def attach_axis( parent: Axis, - orig: Axis, + orig: DenseVariableAxis, nnz: PrimExpr, indptr_var: tvm.tir.Var, idtype: str = "int32", @@ -967,7 +967,7 @@ def attach_axis( f"`attach_axis` expected assign to only one var, but got {names}", span ) - indptr_len = orig.nnz + 1 + indptr_len = orig.parent.length + 1 indptr_buf = tvm.tir.decl_buffer( (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span ) diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index d75e4337f483..df3a3cb82c0b 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -47,6 +47,10 @@ def idtype(self): def nnz(self): return _ffi_api.GetNNZ(self) + @property + def parent(self): + return _ffi_api.GetParent(self) + @tvm._ffi.register_object("tir.sparse.DenseAxis") class DenseAxis(Axis): @@ -168,9 +172,9 @@ class AttachedAxis(DenseVariableAxis): nnz : PrimExpr indptr : PrimExpr - def __init__(self, name, parent, length, nnz, indptr): + def __init__(self, name, parent, orig, nnz, indptr): self.__init_handle_by_constructor__( - _ffi_api.AttachedAxis, name, parent, length, nnz, indptr + _ffi_api.AttachedAxis, name, parent, orig, nnz, indptr ) diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 6709795dc4cb..d68acceb473c 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -46,6 +46,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis) TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->GetNNZ(); }); +TVM_REGISTER_GLOBAL("tir.sparse.GetParent").set_body_typed([](Axis axis) { return axis->GetParentAxis(); }); + /******** AxisNode ********/ std::tuple AxisNode::GetOffsetExtent(SparseCtx* ctx) const { diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 4bdbf688d34b..7338a59066ee 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -134,9 +134,15 @@ class SparseBlockCtx : public SparseCtx { Axis orig = group[j]; SetOffset(orig, offset); if (j > 0) { - // TODO(zihao): support more than sv axis. - offset = lower_bound(Downcast(orig)->indptr->data, offset, - Integer(0), orig->GetNNZ()); + Buffer indptr; + if (auto sv_axis = orig.as()) { + indptr = sv_axis->indptr; + } else if (auto dv_axis = orig.as()) { + indptr = dv_axis->indptr; + } else { + throw; + } + offset = upper_bound(indptr->data, offset, Integer(0), indptr->shape[0]) - 1; } } for (size_t j = 0; j < group.size(); ++j) { @@ -379,7 +385,7 @@ class IndexTransformer : public StmtExprMutator { */ IterVar SpIterVarToIterVar(const SpIterVar& sp_iter, Map var_map) { // Substitute the iteration vars in the expression with the loop vars. - return IterVar(Range::FromMinExtent(0, Substitute(sp_blk_ctx_.GetIterExtent(sp_iter), var_map)), + return IterVar(Range::FromMinExtent(0, sp_blk_ctx_.GetIterExtent(sp_iter)), sp_iter->var, sp_iter->is_reduction ? kCommReduce : kDataPar); } diff --git a/tests/python/sparsetir/test_tir_rgcn.py b/tests/python/sparsetir/bench_rgcn.py similarity index 64% rename from tests/python/sparsetir/test_tir_rgcn.py rename to tests/python/sparsetir/bench_rgcn.py index 1ea88bffa7d3..f0a2d73129bf 100644 --- a/tests/python/sparsetir/test_tir_rgcn.py +++ b/tests/python/sparsetir/bench_rgcn.py @@ -9,6 +9,8 @@ import torch as th from tvm.script import tir as T from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset +from lowered_tir import lowered_rgcn_forward +from sparse_tir_scripts import rgcn_forward class TorchOpTimer(object): @@ -63,69 +65,6 @@ def prepare_graph(g, ntype=None): return g -@T.prim_func -def rgcn( - etype: T.handle, - w: T.handle, - x: T.handle, - y: T.handle, - indptr: T.handle, - indices: T.handle, - n: T.int32, - r: T.int32, - feat_size: T.int32, - nnz: T.int32 -): - I = T.dense_fixed(n) - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - R = T.dense_fixed(r) - F_in = T.dense_fixed(feat_size) - F_out = T.dense_fixed(feat_size) - E = T.match_sparse_buffer(etype, (I, J), "int32") - W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32") - X = T.match_sparse_buffer(x, (T.dense(J), F_in), "float32") - Y = T.match_sparse_buffer(y, (I, F_out), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.iter([I, F_out, J, F_in], "SSRR", "rgcn-forward") as [ - vi, vout, vj, vin, - ]: - with T.init(): - Y[vi, vout] = 0. - Y[vi, vout] = Y[vi, vout] + W[E[vi, vj], vout, vin] * X[vj, vin] - - -@T.prim_func -def lowered_rgcn(etype: T.handle, w: T.handle, x: T.handle, y: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, r: T.int32, feat_size: T.int32, nnz: T.int32) -> None: - E_data = T.match_buffer(etype, [nnz], dtype="int32") - W_data = T.match_buffer(w, [r * feat_size * feat_size], dtype="float32") - X_data = T.match_buffer(x, [n * feat_size], dtype="float32") - Y_data = T.match_buffer(y, [n * feat_size], dtype="float32") - J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") - J_indices = T.match_buffer(indices, [nnz], dtype="int32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - for v_vi, v_vout in T.grid(n, feat_size): - with T.block("rgcn-forward_0"): - vi, vout = T.axis.remap("SS", [v_vi, v_vout]) - T.reads(J_indptr[0: n + 1], J_indices[0: nnz], E_data[0: nnz], W_data[0: r * - feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * feat_size]) - T.writes(Y_data[0: n * feat_size]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(J_indptr[v_vi + 1] - J_indptr[v_vi]): - for v_vin in T.serial(feat_size): - with T.block("rgcn-forward_1"): - vj, vin = T.axis.remap("RR", [v_vj, v_vin]) - T.reads(J_indptr[0: n + 1], J_indices[0: nnz], E_data[0: nnz], W_data[0: r * - feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * feat_size]) - T.writes(Y_data[0: n * feat_size]) - T.block_attr({"sparse": True}) - with T.init(): - Y_data[vi * feat_size + vout] = T.float32(0) - Y_data[vi * feat_size + vout] = Y_data[vi * feat_size + vout] + W_data[( - E_data[J_indptr[vi] + vj] * feat_size + vout) * feat_size + vin] * X_data[J_indices[J_indptr[vi] + vj] * feat_size + vin] - - def test_rgcn(g: DGLHeteroGraph): feat_size = 16 g = g.to(0) @@ -180,9 +119,9 @@ def msg_func(edges): print("dgl high-mem:\t\t", accum / (total - cold_start)) # tir - mod = tvm.IRModule.from_expr(rgcn) + mod = tvm.IRModule.from_expr(rgcn_forward) mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_rgcn, True) + tvm.ir.assert_structural_equal(mod["main"], lowered_rgcn_forward, True) N, R, FEAT_SIZE, NNZ = mod["main"].params[-4:] sch = tir.Schedule( diff --git a/tests/python/sparsetir/lowered_tir.py b/tests/python/sparsetir/lowered_tir.py new file mode 100644 index 000000000000..360a709efa5c --- /dev/null +++ b/tests/python/sparsetir/lowered_tir.py @@ -0,0 +1,396 @@ +"""Lowered TIR scripts of sparse workloads.""" +from tvm.script import tir as T + + +@T.prim_func +def lowered_csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, (nnz,), "float32") + B_data = T.match_buffer(b, (n * k,), "float32") + C_data = T.match_buffer(c, (m * k,), "float32") + J_indptr = T.match_buffer(indptr, (m + 1,), "int32") + J_indices = T.match_buffer(indices, (nnz,), "int32") + # body + # with T.block("root") + for v_vi, v_vk in T.grid(m, k): + with T.block("csrmm0"): + vi, vk = T.axis.remap("SS", [v_vi, v_vk]) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: nnz], B_data[0: n * k], C_data[0: m * k]) + T.writes(C_data[0: m * k]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): + with T.block("csrmm1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: nnz], B_data[0: n * k], C_data[0: m * k]) + T.writes(C_data[0: m * k]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[vi * k + vk] = T.float32(0) + C_data[vi * k + vk] = C_data[vi * k + vk] + A_data[J_indptr[vi] + vj] * \ + B_data[J_indices[J_indptr[vi] + vj] * k + vk] + + +@T.prim_func +def lowered_csrmm_dense_iter(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, (nnz,), "float32") + B_data = T.match_buffer(b, (n * k,), "float32") + C_data = T.match_buffer(c, (m * k,), "float32") + J_indptr = T.match_buffer(indptr, (m + 1,), "int32") + J_indices = T.match_buffer(indices, (nnz,), "int32") + # body + # with T.block("root") + for v_vi, v_vj, v_vk in T.grid(m, n, k): + with T.block("csrmm0"): + vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk]) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: nnz], B_data[0: n * k], C_data[0: m * k]) + T.writes(C_data[0: m * k]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[vi * k + vk] = T.float32(0) + C_data[vi * k + vk] = C_data[vi * k + vk] + A_data[T.tvm_lower_bound( + J_indices.data, vj, J_indptr[vi], J_indptr[vi + 1], dtype="int32")] * B_data[vj * k + vk] + + +@T.prim_func +def lowered_csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, m: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, [nnz], dtype="float32") + B_data = T.match_buffer(b, [n], dtype="float32") + J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + for v_vi in T.serial(0, n): + with T.block("csr_reduce_outer"): + vi = T.axis.spatial(n, v_vi) + T.reads([J_indptr[0: n + 1], J_indices[0: nnz], A_data[0: nnz], B_data[0: n]]) + T.writes([B_data[0: n]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): + with T.block("csr_reduce"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads([J_indptr[0: n + 1], J_indices[0: nnz], A_data[0: nnz], B_data[0: n]]) + T.writes([B_data[0: n]]) + T.block_attr({"sparse": True}) + with T.init(): + B_data[vi] = T.float32(0) + B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] + + +@T.prim_func +def lowered_segment_reduce(a: T.handle, b: T.handle, indptr: T.handle, n: T.int32, nnz: T.int32) -> None: + A_data = T.match_buffer(a, (nnz,), "float32") + B_data = T.match_buffer(b, (n,), "float32") + J_indptr = T.match_buffer(indptr, (n + 1,), "int32") + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for v_vi in T.serial(n): + with T.block("segment_reduce0"): + vi = T.axis.spatial(n, v_vi) + T.reads(J_indptr[0: n + 1], A_data[0: nnz], B_data[0: n]) + T.writes(B_data[0: n]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): + with T.block("segment_reduce1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads(J_indptr[0: n + 1], A_data[0: nnz], B_data[0: n]) + T.writes(B_data[0: n]) + T.block_attr({"sparse": True}) + with T.init(): + B_data[vi] = T.float32(0) + B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] + + +@T.prim_func +def lowered_bsrmm(a: T.handle, b: T.handle, c: T.handle, j_indptr: T.handle, j_indices: T.handle, nb: T.int32, mb: T.int32, nnzb: T.int32, blk: T.int32, feat_size: T.int32) -> None: + A_data = T.match_buffer(a, (nnzb * blk * blk,), "float32") + B_data = T.match_buffer(b, (mb * blk * feat_size,), "float32") + C_data = T.match_buffer(c, (nb * blk * feat_size,), "float32") + J_indptr = T.match_buffer(j_indptr, (nb + 1,), "int32") + J_indices = T.match_buffer(j_indices, (nnzb,), "int32") + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for v_vi, v_vbi, v_vbj, v_vf in T.grid(nb, blk, blk, feat_size): + with T.block("bsrmm0"): + vi, vbi, vbj, vf = T.axis.remap("SSRS", [v_vi, v_vbi, v_vbj, v_vf]) + T.reads(J_indptr[0: nb + 1], J_indices[0: nnzb], A_data[0: nnzb * blk * blk], + B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]) + T.writes(C_data[0: nb * blk * feat_size]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) + for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): + with T.block("bsrmm1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads(J_indptr[0: nb + 1], J_indices[0: nnzb], A_data[0: nnzb * blk * blk], + B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]) + T.writes(C_data[0: nb * blk * feat_size]) + T.block_attr({"sparse": True}) + C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[( + (J_indptr[vi] + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[J_indptr[vi] + vj] * blk + vbj) * feat_size + vf] + + +@T.prim_func +def lowered_ellmm(a: T.handle, b: T.handle, c: T.handle, j_indices: T.handle, nb: T.int32, mb: T.int32, feat_size: T.int32, col: T.int32, blk: T.int32) -> None: + A_data = T.match_buffer(a, (nb * col * blk * blk,), "float32") + B_data = T.match_buffer(b, (mb * blk * feat_size,), "float32") + C_data = T.match_buffer(c, (nb * blk * feat_size,), "float32") + J_indices = T.match_buffer(j_indices, (nb * col,), "int32") + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size): + with T.block("ellmm0"): + vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf]) + T.reads(J_indices[0: nb * col], A_data[0: nb * col * blk * blk], + B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]) + T.writes(C_data[0: nb * blk * feat_size]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) + C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[((vi * + col + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] + + +@T.prim_func +def lowered_sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, (m * k,), "float32") + B_data = T.match_buffer(b, (n * k,), "float32") + C_data = T.match_buffer(c, (nnz,), "float32") + J_indptr = T.match_buffer(indptr, (m + 1,), "int32") + J_indices = T.match_buffer(indices, (nnz,), "int32") + for v_vi in T.serial(m): + with T.block("sddmm0"): + vi = T.axis.spatial(m, v_vi) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: m * k], B_data[0: n * k], C_data[0: nnz]) + T.writes(C_data[0: nnz]) + T.block_attr({"sparse": True}) + for v_vj, v_vk in T.grid(J_indptr[vi + 1] - J_indptr[vi], k): + with T.block("sddmm1"): + vj, vk = T.axis.remap("SR", [v_vj, v_vk]) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: m * k], B_data[0: n * k], C_data[0: nnz]) + T.writes(C_data[0: nnz]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[J_indptr[vi] + vj] = T.float32(0) + C_data[J_indptr[vi] + vj] = C_data[J_indptr[vi] + vj] + \ + A_data[vi * k + vk] * B_data[J_indices[J_indptr[vi] + vj] * k + vk] + + +# from tvm.script import tir as T +@T.prim_func +def lowered_sddmm_fuse(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, (m * k,), "float32") + B_data = T.match_buffer(b, (n * k,), "float32") + C_data = T.match_buffer(c, (nnz,), "float32") + J_indptr = T.match_buffer(indptr, (m + 1,), "int32") + J_indices = T.match_buffer(indices, (nnz,), "int32") + # body + # with T.block("root") + for v_vi, v_vj, v_vk in T.grid(1, nnz, k): + with T.block("sddmm0"): + vi, vj, vk = T.axis.remap("SSR", [v_vi, v_vj, v_vk]) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: m * k], B_data[0: n * k], C_data[0: nnz]) + T.writes(C_data[0: nnz]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[vj] = T.float32(0) + C_data[vj] = C_data[vj] + A_data[(T.tvm_upper_bound(J_indptr.data, vj, 0, + m + 1, dtype="int32") - 1) * k + vk] * B_data[J_indices[vj] * k + vk] + + +@T.prim_func +def lowered_bmm( + x: T.handle, + y: T.handle, + z: T.handle, + indptr_i: T.handle, + indptr_j: T.handle, + indptr_k: T.handle, + indptr_ij: T.handle, + indptr_jk: T.handle, + indptr_ik: T.handle, + batch_size: T.int32, + nnz_i: T.int32, + nnz_j: T.int32, + nnz_k: T.int32, + nnz_ij: T.int32, + nnz_jk: T.int32, + nnz_ik: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + X_data = T.match_buffer(x, (nnz_ij,), "float32") + Y_data = T.match_buffer(y, (nnz_jk,), "float32") + Z_data = T.match_buffer(z, (nnz_ik,), "float32") + I_indptr = T.match_buffer(indptr_i, (batch_size + 1,), "int32") + J_indptr = T.match_buffer(indptr_j, (batch_size + 1,), "int32") + K_indptr = T.match_buffer(indptr_k, (batch_size + 1,), "int32") + IJ_indptr = T.match_buffer(indptr_ij, (batch_size + 1,), "int32") + JK_indptr = T.match_buffer(indptr_jk, (batch_size + 1,), "int32") + IK_indptr = T.match_buffer(indptr_ik, (batch_size + 1,), "int32") + # body + # with T.block("root") + for v_vb in T.serial(batch_size): + with T.block("bmm0"): + vb = T.axis.spatial(batch_size, v_vb) + T.reads(I_indptr[0: batch_size + 1], J_indptr[0: batch_size + 1], K_indptr[0: batch_size + 1], IJ_indptr[0: batch_size + 1], + JK_indptr[0: batch_size + 1], IK_indptr[0: batch_size + 1], X_data[0: nnz_ij], Y_data[0: nnz_jk], Z_data[0: nnz_ik]) + T.writes(Z_data[0: nnz_ik]) + T.block_attr({"sparse": True}) + for v_vi, v_vj, v_vk in T.grid(I_indptr[vb + 1] - I_indptr[vb], J_indptr[vb + 1] - J_indptr[vb], K_indptr[vb + 1] - K_indptr[vb]): + with T.block("bmm1"): + vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk]) + T.reads(I_indptr[0: batch_size + 1], J_indptr[0: batch_size + 1], K_indptr[0: batch_size + 1], IJ_indptr[0: batch_size + 1], + JK_indptr[0: batch_size + 1], IK_indptr[0: batch_size + 1], X_data[0: nnz_ij], Y_data[0: nnz_jk], Z_data[0: nnz_ik]) + T.writes(Z_data[0: nnz_ik]) + T.block_attr({"sparse": True}) + with T.init(): + Z_data[IK_indptr[vb] + vi * + (K_indptr[vb + 1] - K_indptr[vb]) + vk] = T.float32(0) + Z_data[IK_indptr[vb] + vi * (K_indptr[vb + 1] - K_indptr[vb]) + vk] = Z_data[IK_indptr[vb] + vi * (K_indptr[vb + 1] - K_indptr[vb]) + vk] + \ + X_data[IJ_indptr[vb] + vi * (J_indptr[vb + 1] - J_indptr[vb]) + vj] * \ + Y_data[JK_indptr[vb] + vj * (K_indptr[vb + 1] - K_indptr[vb]) + vk] + + +@T.prim_func +def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, [nnz_k], dtype="float32") + B_data = T.match_buffer(b, [M], dtype="float32") + J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32") + J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") + K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") + K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32") + + for v_vi in T.serial(0, M): + with T.block("square_sum_2"): + vi = T.axis.spatial(M, v_vi) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K_indptr[0: nnz_j + 1], + K_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): + with T.block("square_sum_1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K_indptr[0: nnz_j + 1], + K_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + with T.init(): + B_data[vi] = T.float32(0) + for v_vk in T.serial(0, K_indptr[J_indptr[vi] + vj + 1] - K_indptr[J_indptr[vi] + vj]): + with T.block("square_sum"): + vk = T.axis.reduce( + K_indptr[J_indptr[vi] + vj + 1] - K_indptr[J_indptr[vi] + vj], v_vk) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K_indptr[0: nnz_j + 1], + K_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk] + + +@T.prim_func +def lowered_square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, [nnz_k], dtype="float32") + B_data = T.match_buffer(b, [M], dtype="float32") + J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32") + J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") + K0_indptr = T.match_buffer(indptr_k0, [nnz_j + 1], dtype="int32") + K0_indices = T.match_buffer(indices_k0, [nnz_k], dtype="int32") + K1_indptr = T.match_buffer(indptr_k1, [nnz_j + 1], dtype="int32") + K1_indices = T.match_buffer(indices_k1, [nnz_k], dtype="int32") + + for v_vi in T.serial(0, M): + with T.block("square_sum_2"): + vi = T.axis.spatial(M, v_vi) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K0_indptr[0: nnz_j + 1], K0_indices[0: nnz_k], + K1_indptr[0: nnz_j + 1], K1_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): + with T.block("square_sum_1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K0_indptr[0: nnz_j + 1], K0_indices[0: nnz_k], + K1_indptr[0: nnz_j + 1], K1_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + with T.init(): + B_data[vi] = T.float32(0) + for v_vk in T.serial(0, K1_indptr[J_indptr[vi] + vj + 1] - K1_indptr[J_indptr[vi] + vj]): + with T.block("square_sum"): + vk = T.axis.reduce( + K1_indptr[J_indptr[vi] + vj + 1] - K1_indptr[J_indptr[vi] + vj], v_vk) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K0_indptr[0: nnz_j + 1], K0_indices[0: nnz_k], + K1_indptr[0: nnz_j + 1], K1_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + B_data[vi] = B_data[vi] + A_data[T.tvm_lower_bound( + K0_indices.data, K1_indices[K1_indptr[J_indptr[vi] + vj] + vk], K0_indptr[J_indptr[vi] + vj], K0_indptr[J_indptr[vi] + vj + 1], dtype="int32")] + + +@T.prim_func +def lowered_csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, [nnz], dtype="float32") + B_data = T.match_buffer(b, [nnz], dtype="float32") + J_indptr = T.match_buffer(indptr, [m + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + for v_vi in T.serial(0, m): + with T.block("csr_element_wise_outer"): + vi = T.axis.spatial(m, v_vi) + T.reads([J_indptr[0: m + 1], J_indices[0: nnz], A_data[0: nnz]]) + T.writes([B_data[0: nnz]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): + with T.block("csr_element_wise"): + vj = T.axis.spatial(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads([J_indptr[0: m + 1], J_indices[0: nnz], A_data[0: nnz]]) + T.writes([B_data[0: nnz]]) + T.block_attr({"sparse": True}) + B_data[J_indptr[vi] + vj] = A_data[J_indptr[vi] + vj] * T.float32(2.5) + + +@T.prim_func +def lowered_rgcn_forward(etype: T.handle, w: T.handle, x: T.handle, y: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, r: T.int32, feat_size: T.int32, nnz: T.int32) -> None: + E_data = T.match_buffer(etype, [nnz], dtype="int32") + W_data = T.match_buffer(w, [r * feat_size * feat_size], dtype="float32") + X_data = T.match_buffer(x, [n * feat_size], dtype="float32") + Y_data = T.match_buffer(y, [n * feat_size], dtype="float32") + J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for v_vi, v_vout in T.grid(n, feat_size): + with T.block("rgcn-forward_0"): + vi, vout = T.axis.remap("SS", [v_vi, v_vout]) + T.reads(J_indptr[0: n + 1], J_indices[0: nnz], E_data[0: nnz], W_data[0: r * + feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * feat_size]) + T.writes(Y_data[0: n * feat_size]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): + for v_vin in T.serial(feat_size): + with T.block("rgcn-forward_1"): + vj, vin = T.axis.remap("RR", [v_vj, v_vin]) + T.reads(J_indptr[0: n + 1], J_indices[0: nnz], E_data[0: nnz], W_data[0: r * + feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * feat_size]) + T.writes(Y_data[0: n * feat_size]) + T.block_attr({"sparse": True}) + with T.init(): + Y_data[vi * feat_size + vout] = T.float32(0) + Y_data[vi * feat_size + vout] = Y_data[vi * feat_size + vout] + W_data[( + E_data[J_indptr[vi] + vj] * feat_size + vout) * feat_size + vin] * X_data[J_indices[J_indptr[vi] + vj] * feat_size + vin] diff --git a/tests/python/sparsetir/sparse_tir_scripts.py b/tests/python/sparsetir/sparse_tir_scripts.py new file mode 100644 index 000000000000..778186b48942 --- /dev/null +++ b/tests/python/sparsetir/sparse_tir_scripts.py @@ -0,0 +1,313 @@ +from tvm.script import tir as T + + +@T.prim_func +def csrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + m: T.int32, + n: T.int32, + k: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") + with T.iter([I, K, J], "SSR", "csrmm") as [vi, vk, vj]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def csrmm_dense_iter( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + m: T.int32, + n: T.int32, + k: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") + with T.iter([I, T.dense(J), K], "SRS", "csrmm") as [vi, vj, vk]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def segment_reduce( + a: T.handle, + b: T.handle, + indptr: T.handle, + n: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(n) + J = T.dense_variable(I, (100, nnz), indptr, "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + with T.iter([I, J], "SR", "segment_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0. + B[vi] = B[vi] + A[vi, vj] + + +@T.prim_func +def csr_reduce( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(n) + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + with T.iter([I, J], "SR", "csr_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj] + + +@T.prim_func +def bsrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + nnzb: T.int32, + blk: T.int32, + feat_size: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(nb) + J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + F = T.dense_fixed(feat_size) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter([I, BI, BJ, F, J], "SSRSR", "bsrmm") as [ + vi, + vbi, + vbj, + vf, + vj, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def ellmm( + a: T.handle, + b: T.handle, + c: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + feat_size: T.int32, + col: T.int32, + blk: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(nb) + J = T.sparse_fixed(I, (mb, col), indices, "int32") + F = T.dense_fixed(feat_size) + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter([I, J, BI, BJ, F], "SRSRS", "ellmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def csr_element_wise( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + m: T.int32, + n: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I, J), "float32") + + with T.iter([I, J], "SS", "csr_element_wise") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.5 + + +@T.prim_func +def bmm( + x: T.handle, + y: T.handle, + z: T.handle, + indptr_i: T.handle, + indptr_j: T.handle, + indptr_k: T.handle, + indptr_ij: T.handle, + indptr_jk: T.handle, + indptr_ik: T.handle, + batch_size: T.int32, + nnz_i: T.int32, + nnz_j: T.int32, + nnz_k: T.int32, + nnz_ij: T.int32, + nnz_jk: T.int32, + nnz_ik: T.int32 +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + B = T.dense_fixed(batch_size) + I = T.dense_variable(B, (32768, nnz_i), indptr_i, "int32") + J = T.dense_variable(B, (32768, nnz_j), indptr_j, "int32") + K = T.dense_variable(B, (32768, nnz_k), indptr_k, "int32") + IJ = T.attach_axis(I, J, nnz_ij, indptr_ij, "int32") + JK = T.attach_axis(J, K, nnz_jk, indptr_jk, "int32") + IK = T.attach_axis(I, K, nnz_ik, indptr_ik, "int32") + X = T.match_sparse_buffer(x, (B, I, IJ), "float32") + Y = T.match_sparse_buffer(y, (B, J, JK), "float32") + Z = T.match_sparse_buffer(z, (B, I, IK), "float32") + with T.iter([B, I, J, K], "SSRS", "bmm") as [vb, vi, vj, vk]: + with T.init(): + Z[vb, vi, vk] = 0. + Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vj] * Y[vb, vj, vk] + + +@T.prim_func +def sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, K), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, J), "float32") + + with T.iter([I, J, K], "SSR", "sddmm") as [vi, vj, vk]: + with T.init(): + C[vi, vj] = 0. + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def fused_sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, K), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, J), "float32") + + with T.iter([T.fuse(I, J), K], "SSR", "sddmm") as [vi, vj, vk]: + with T.init(): + C[vi, vj] = 0. + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(M) + J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32") + K = T.sparse_variable(J, (N2, nnz_k), (indptr_k, indices_k), "int32") + A = T.match_sparse_buffer(a, (I, J, K), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + + with T.iter([I, J, K], "SRR", "square_sum") as [vi, vj, vk]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj, vk] + + +@T.prim_func +def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32): + # Used only for testing `GetIndicesRange()`. + # Currently it is ensured that `indptr_k0` is the same as `indptr_k1`, and `indices_k0` is the + # same as `indices_k1`. + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(M) + J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32") + K0 = T.sparse_variable(J, (N2, nnz_k), (indptr_k0, indices_k0), "int32") + K1 = T.sparse_variable(J, (N2, nnz_k), (indptr_k1, indices_k1), "int32") + A = T.match_sparse_buffer(a, (I, J, K0), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + + with T.iter([I, J, K1], "SRR", "square_sum") as [vi, vj, vk]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj, vk] + + +@T.prim_func +def rgcn_forward( + etype: T.handle, + w: T.handle, + x: T.handle, + y: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + r: T.int32, + feat_size: T.int32, + nnz: T.int32 +): + I = T.dense_fixed(n) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + R = T.dense_fixed(r) + F_in = T.dense_fixed(feat_size) + F_out = T.dense_fixed(feat_size) + E = T.match_sparse_buffer(etype, (I, J), "int32") + W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32") + X = T.match_sparse_buffer(x, (T.dense(J), F_in), "float32") + Y = T.match_sparse_buffer(y, (I, F_out), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.iter([I, F_out, J, F_in], "SSRR", "rgcn-forward") as [ + vi, vout, vj, vin, + ]: + with T.init(): + Y[vi, vout] = 0. + Y[vi, vout] = Y[vi, vout] + W[E[vi, vj], vout, vin] * X[vj, vin] diff --git a/tests/python/sparsetir/test_tir_sparse_correctness.py b/tests/python/sparsetir/test_tir_sparse_correctness.py index a69170179dc6..506bc2998248 100644 --- a/tests/python/sparsetir/test_tir_sparse_correctness.py +++ b/tests/python/sparsetir/test_tir_sparse_correctness.py @@ -14,256 +14,193 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from ctypes import c_float import tvm import tvm.testing -from tvm.runtime.ndarray import device import tvm.tir as tir import scipy.sparse as sp import numpy as np from tvm.script import tir as T - - -@T.prim_func -def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ: T.int32) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_data = T.match_buffer(a, (NNZ,), "float32") - B = T.match_buffer(b, (N * K,), "float32") - C = T.match_buffer(c, (M * K,), "float32") - A_indptr = T.match_buffer(indptr, (M + 1,), "int32") - A_indices = T.match_buffer(indices, (NNZ,), "int32") - for i, k in T.grid(M, K): - with T.block("spmm_outer"): - vi, vk = T.axis.remap("SS", [i, k]) - with T.init(): - C[vi * K + vk] = 0. - for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]): - with T.block("spmm_inner"): - T.block_attr({"sparse": True}) - vj = T.axis.R(NNZ, j + A_indptr[vi]) - C[vi * K + vk] = C[vi * K + vk] + \ - A_data[vj] * B[A_indices[vj] * K + vk] - - -@T.prim_func -def bsrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, MB: T.int32, NB: T.int32, K: T.int32, BLOCK_SIZE: T.int32, NNZB: T.int32) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_data = T.match_buffer(a, (NNZB * BLOCK_SIZE * BLOCK_SIZE), "float32") - B = T.match_buffer(b, (NB * BLOCK_SIZE * K,), "float32") - C = T.match_buffer(c, (MB * BLOCK_SIZE * K,), "float32") - A_indptr = T.match_buffer(indptr, (MB + 1,), "int32") - A_indices = T.match_buffer(indices, (NNZB,), "int32") - for io, ii, ji, k in T.grid(MB, BLOCK_SIZE, BLOCK_SIZE, K): - with T.block("spmm_outer"): - vio, vii, vji, vk = T.axis.remap("SSSS", [io, ii, ji, k]) - with T.init(): - C[(vio * BLOCK_SIZE + vii) * K + vk] = 0. - for jo in T.serial(0, A_indptr[vio + 1] - A_indptr[vio]): - with T.block("spmm_inner"): - T.block_attr({"sparse": True}) - vjo = T.axis.R(NNZB, jo + A_indptr[vio]) - C[(vio * BLOCK_SIZE + vii) * K + vk] = C[(vio * BLOCK_SIZE + vii) * K + vk] + A_data[( - vjo * BLOCK_SIZE + vii) * BLOCK_SIZE + vji] * B[(A_indices[vjo] * BLOCK_SIZE + vji) * K + vk] - - -@T.prim_func -def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ_COLS: T.int32) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_data = T.match_buffer(a, (M * NNZ_COLS,), "float32") - B = T.match_buffer(b, (N * K,), "float32") - C = T.match_buffer(c, (M * K,), "float32") - A_indices = T.match_buffer(indices, (M * NNZ_COLS,), "int32") - for i, j, k in T.grid(M, NNZ_COLS, K): - with T.block("spmm"): - T.block_attr({"sparse": True}) - vi, vj, vk = T.axis.remap("SRS", [i, j, k]) - with T.init(): - C[vi * K + vk] = 0. - C[vi * K + vk] = C[vi * K + vk] + A_data[vi * NNZ_COLS + vj] * \ - B[A_indices[vi * NNZ_COLS + vj] * K + vk] - - -@T.prim_func -def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ: T.int32) -> None: - T.func_attr({"global_symbol": "main", "tir.noalis": True}) - A = T.match_buffer(a, (M * K,), "float32") - B = T.match_buffer(b, (N * K,), "float32") - C_data = T.match_buffer(c, (NNZ,), "float32") - C_indptr = T.match_buffer(indptr, (M + 1,), "int32") - C_indices = T.match_buffer(indices, (NNZ,), "int32") - for ij, k in T.grid(NNZ, K): - with T.block("sddmm"): - T.block_attr({"sparse": True}) - vij, vk = T.axis.remap("SR", [ij, k]) - T.reads([A[0: M * K], B[0: N * K], C_data[vij], C_indices[vij], C_indptr[0: M + 1]]) - T.writes([C_data[vij]]) - with T.init(): - C_data[vij] = 0. - C_data[vij] = C_data[vij] + \ - A[(T.upper_bound(C_indptr.data, vij, 0, M + 1) - 1) * K + vk] * B[C_indices[vij] * K + vk] - - -@T.prim_func -def bmm_tir(a: T.handle, b: T.handle, c: T.handle, - indptr_i: T.handle, indptr_j: T.handle, indptr_k: T.handle, - indptr_ij: T.handle, indptr_jk: T.handle, indptr_ik: T.handle, - BATCH: T.int32, - NNZ_IJ: T.int32, NNZ_JK: T.int32, NNZ_IK: T.int32) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, (NNZ_IJ,), "float32") - B = T.match_buffer(b, (NNZ_JK,), "float32") - C = T.match_buffer(c, (NNZ_IK,), "float32") - indptr_I = T.match_buffer(indptr_i, (BATCH + 1,), "int32") - indptr_J = T.match_buffer(indptr_j, (BATCH + 1,), "int32") - indptr_K = T.match_buffer(indptr_k, (BATCH + 1,), "int32") - indptr_IJ = T.match_buffer(indptr_ij, (BATCH + 1,), "int32") - indptr_JK = T.match_buffer(indptr_jk, (BATCH + 1,), "int32") - indptr_IK = T.match_buffer(indptr_ik, (BATCH + 1,), "int32") - for b in T.grid(BATCH): - with T.block("bmm_outer"): - T.block_attr({"sparse": True}) - vb = T.axis.S(BATCH, b) - with T.init(): - T.evaluate(1) - for i, j, k in T.grid(indptr_I[vb + 1] - indptr_I[vb], indptr_J[vb + 1] - indptr_J[vb], indptr_K[vb + 1] - indptr_K[vb]): - with T.block("bmm_inner"): - T.block_attr({"sparse": True}) - vi, vj, vk = T.axis.remap("SRS", [i, j, k]) - with T.init(): - C[indptr_IK[vb] + vi * (indptr_K[vb + 1] - indptr_K[vb]) + vk] = 0. - C[indptr_IK[vb] + vi * (indptr_K[vb + 1] - indptr_K[vb]) + vk] = C[indptr_IK[vb] + vi * (indptr_K[vb + 1] - indptr_K[vb]) + vk] +\ - A[indptr_IJ[vb] + vi * (indptr_J[vb + 1] - indptr_J[vb]) + vj] * \ - B[indptr_JK[vb] + vj * (indptr_K[vb + 1] - indptr_K[vb]) + vk] +from lowered_tir import * def test_csrmm(): - # generate random input - m = 4096 - n = 4096 - k = 256 - A = sp.random(m, n, dtype="float32", density=0.0125, format='csr') - nnz = A.nnz - x = np.random.rand(n, k).astype("float32") + A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr") + x = np.random.rand(512, 128).astype("float32") y_ground_truth = A * x - y = np.zeros((m * k,)).astype("float32") - - # specialize function - _, _, _, _, _, M, N, K, NNZ = csrmm_tir.params - sch = tir.Schedule( - csrmm_tir.specialize( - {M: m, N: n, K: k, NNZ: nnz} - ) - ) - blk_outer = sch.get_block("spmm_outer") - i, k = sch.get_loops(blk_outer) - sch.bind(i, "blockIdx.x") - sch.bind(k, "threadIdx.x") + y = np.zeros((512, 128)).astype("float32") - # convert numpy tensor to tvm ndarray - A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=tvm.cuda(0)) - A_indices = tvm.nd.array(A.indices.astype("int32"), device=tvm.cuda(0)) - A_data = tvm.nd.array(A.data.astype("float32"), device=tvm.cuda(0)) - X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0)) - Y_nd = tvm.nd.array(y, device=tvm.cuda(0)) + n, m, k, nnz = lowered_csrmm.params[-4:] + f = tvm.build(lowered_csrmm.specialize({n: 512, m: 512, k: 128, nnz: A.nnz}), target="llvm") - # build function - f = tvm.build(sch.mod, target='cuda') + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) + X_nd = tvm.nd.array(x.reshape(-1), device=ctx) + Y_nd = tvm.nd.array(y.reshape(-1), device=ctx) f(A_data, X_nd, Y_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) - # assertion - tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5) + +def test_csr_reduce(): + A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") + b_ground_truth = np.array(np.sum(A, axis=1)) + b = np.zeros((128,)).astype("float32") + + n, m, nnz = lowered_csr_reduce.params[-3:] + f = tvm.build(lowered_csr_reduce.specialize({n: 128, m: 128, nnz: A.nnz}), target="llvm") + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) + B_nd = tvm.nd.array(b, device=ctx) + f(A_data, B_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(b_ground_truth.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_csr_element_wise(): + A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") + b_ground_truth = A * 2.5 + b = np.zeros((A.nnz,)).astype("float32") + + m, n, nnz = lowered_csr_element_wise.params[-3:] + f = tvm.build(lowered_csr_element_wise.specialize({m: 128, n: 128, nnz: A.nnz}), target="llvm") + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) + B_nd = tvm.nd.array(b, device=ctx) + f(A_data, B_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(b_ground_truth.data.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) def test_bsrmm(): - # generate random input - block_size = 1 - mb = 64 - nb = 64 - k = 256 - m = mb * block_size + block_size = 16 + nb = 32 + mb = 32 + feat_size = 256 n = nb * block_size - A_block = sp.random(mb, nb, dtype="float32", density=0.05, format='csr') + m = mb * block_size + + A_block = sp.random(mb, nb, dtype="float32", density=0.05, format="csr") indptr = A_block.indptr indices = A_block.indices nnzb = A_block.nnz data = np.random.rand(nnzb, block_size, block_size) - A = sp.bsr_matrix((data, indices, indptr), shape=(m, n)) - x = np.random.rand(n, k).astype("float32") + A = sp.bsr_matrix((data, indices, indptr), shape=(n, m)) + x = np.random.rand(m, feat_size).astype("float32") y_ground_truth = A * x - y = np.zeros((m * k,)).astype("float32") - - # specialize function - _, _, _, _, _, MB, NB, K, BLOCK_SIZE, NNZB = bsrmm_tir.params - sch = tir.Schedule( - bsrmm_tir.specialize( - {MB: mb, NB: nb, K: k, BLOCK_SIZE: block_size, NNZB: nnzb} - ) + y = np.zeros((n * feat_size,)).astype("float32") + + v_nb, v_mb, v_nnzb, v_blk, v_feat_size = lowered_bsrmm.params[-5:] + f = tvm.build( + lowered_bsrmm.specialize( + {v_nb: nb, v_mb: mb, v_nnzb: nnzb, v_blk: block_size, v_feat_size: feat_size} + ), + target="llvm", ) - blk_outer = sch.get_block("spmm_outer") - io, ii, ji, k = sch.get_loops(blk_outer) - sch.unroll(ii) - sch.unroll(ji) - sch.bind(io, "blockIdx.x") - sch.bind(k, "threadIdx.x") - - # convert numpy tensor to tvm ndarray - A_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0)) - A_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0)) - A_data = tvm.nd.array( - data.reshape(-1).astype("float32"), device=tvm.cuda(0)) - X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0)) - Y_nd = tvm.nd.array(y, device=tvm.cuda(0)) - # build function - f = tvm.build(sch.mod, target="cuda") + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(data.reshape(-1).astype("float32"), device=ctx) + X_nd = tvm.nd.array(x.reshape(-1), device=ctx) + Y_nd = tvm.nd.array(y, device=ctx) f(A_data, X_nd, Y_nd, A_indptr, A_indices) - - # assertion - tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) def test_ellmm(): + nnz_cols = 4 + nb = 64 + mb = 64 + feat_size = 1024 + nnz = nb * nnz_cols + block_size = 16 + n = nb * block_size + m = mb * block_size + + rng = np.random.default_rng() + indptr = np.arange(0, (nb + 1) * nnz_cols, nnz_cols) + indices = np.array([rng.choice(mb, size=nnz_cols, replace=False) for i in range(nb)]) + order = indices.argsort(axis=1) + indices = np.array([indices[i, order[i]] for i in range(0, nb)]).reshape(-1) + data = np.random.rand(nnz, block_size, block_size) + A = sp.bsr_matrix((data, indices, indptr), shape=(n, m)) + x = np.random.rand(m, feat_size).astype("float32") + y_ground_truth = A * x + y = np.zeros((n * feat_size,)).astype("float32") + + v_nb, v_mb, v_feat_size, v_col, v_blk = lowered_ellmm.params[-5:] + f = tvm.build( + lowered_ellmm.specialize( + { + v_nb: nb, + v_mb: mb, + v_feat_size: feat_size, + v_col: nnz_cols, + v_blk: block_size, + } + ), + target="llvm", + ) + + ctx = tvm.cpu(0) + A_indices = tvm.nd.array(indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(data.reshape(-1).astype("float32"), device=ctx) + X_nd = tvm.nd.array(x.reshape(-1), device=ctx) + Y_nd = tvm.nd.array(y, device=ctx) + f(A_data, X_nd, Y_nd, A_indices) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_sddmm(): # generate random input - nnz_cols = 64 m = 4096 n = 4096 k = 256 - nnz = nnz_cols * m - indptr = np.arange(0, (m + 1) * nnz_cols, nnz_cols) - indices = np.random.randint(0, n, size=(nnz,)) - data = np.random.rand(nnz) - A = sp.csr_matrix((data, indices, indptr), shape=(m, n)) - x = np.random.rand(n, k).astype("float32") - y_ground_truth = A * x - y = np.zeros((m * k,)).astype("float32") + C = sp.random(m, n, dtype="float32", density=0.0125, format='csr') + indptr = C.indptr + indices = C.indices + C_coo = C.tocoo() + nnz = C.nnz + x = np.random.rand(m, k).astype("float32") + y = np.random.rand(n, k).astype("float32") + z_ground_truth = np.matmul(x, y.transpose())[C_coo.row, C_coo.col] + z = np.zeros((nnz,)).astype("float32") + # specialize function - _, _, _, _, M, N, K, NNZ_COLS = ellmm_tir.params + _, _, _, _, _, M, N, K, NNZ = lowered_sddmm.params sch = tir.Schedule( - ellmm_tir.specialize( - {M: m, N: n, K: k, NNZ_COLS: nnz_cols} + lowered_sddmm.specialize( + {M: m, N: n, K: k, NNZ: nnz} ) ) - blk = sch.get_block("spmm") - i, j, k = sch.get_loops(blk) + blk_outer = sch.get_block("sddmm0") + blk_inner = sch.get_block("sddmm1") + i, = sch.get_loops(blk_outer) + _, k = sch.get_loops(blk_inner) sch.bind(i, "blockIdx.x") sch.bind(k, "threadIdx.x") - sch.unroll(j) # convert numpy tensor to tvm ndarray - A_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0)) - A_data = tvm.nd.array(data.astype("float32"), device=tvm.cuda(0)) + C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0)) + C_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0)) X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0)) - Y_nd = tvm.nd.array(y, device=tvm.cuda(0)) + Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0)) + C_data = tvm.nd.array(z, device=tvm.cuda(0)) # build function - f = tvm.build(sch.mod, target="cuda") - f(A_data, X_nd, Y_nd, A_indices) + f = tvm.build(sch.mod['main'], target="cuda") + f(X_nd, Y_nd, C_data, C_indptr, C_indices) # assertion - tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5) + tvm.testing.assert_allclose(z_ground_truth, C_data.numpy(), rtol=1e-5) -def test_sddmm(): +def test_sddmm_fuse(): # generate random input m = 4096 n = 4096 @@ -279,15 +216,16 @@ def test_sddmm(): z = np.zeros((nnz,)).astype("float32") # specialize function - _, _, _, _, _, M, N, K, NNZ = sddmm_tir.params + _, _, _, _, _, M, N, K, NNZ = lowered_sddmm_fuse.params sch = tir.Schedule( - sddmm_tir.specialize( + lowered_sddmm_fuse.specialize( {M: m, N: n, K: k, NNZ: nnz} ) ) - blk = sch.get_block("sddmm") - ij, k = sch.get_loops(blk) - sch.bind(ij, "blockIdx.x") + blk = sch.get_block("sddmm0") + i, j, k = sch.get_loops(blk) + sch.unroll(i) + sch.bind(j, "blockIdx.x") sch.bind(k, "threadIdx.x") # convert numpy tensor to tvm ndarray @@ -320,6 +258,9 @@ def test_bmm(): indptr_nm = np.concatenate(([0], nm_arr)).cumsum() indptr_mk = np.concatenate(([0], mk_arr)).cumsum() indptr_nk = np.concatenate(([0], nk_arr)).cumsum() + nnz_i = indptr_n[-1] + nnz_j = indptr_m[-1] + nnz_k = indptr_k[-1] nnz_ij = indptr_nm[-1] nnz_jk = indptr_mk[-1] nnz_ik = indptr_nk[-1] @@ -337,15 +278,15 @@ def test_bmm(): c_flatten = np.concatenate([C.flatten() for C in Cs], 0) # specialize function - _, _, _, _, _, _, _, _, _, BATCH, NNZ_IJ, NNZ_JK, NNZ_IK = bmm_tir.params + _, _, _, _, _, _, _, _, _, BATCH, NNZ_I, NNZ_J, NNZ_K, NNZ_IJ, NNZ_JK, NNZ_IK = lowered_bmm.params sch = tir.Schedule( - bmm_tir.specialize({ - BATCH: batch_size, NNZ_IJ: nnz_ij, NNZ_JK: nnz_jk, NNZ_IK: nnz_ik + lowered_bmm.specialize({ + BATCH: batch_size, NNZ_I: nnz_i, NNZ_J: nnz_j, NNZ_K: nnz_k, NNZ_IJ: nnz_ij, NNZ_JK: nnz_jk, NNZ_IK: nnz_ik }) ) - bmm_outer = sch.get_block("bmm_outer") + bmm_outer = sch.get_block("bmm0") b, = sch.get_loops(bmm_outer) - bmm_inner = sch.get_block("bmm_inner") + bmm_inner = sch.get_block("bmm1") i, j, k = sch.get_loops(bmm_inner) sch.reorder(i, k, j) io, ii = sch.split(i, [None, 32]) @@ -369,16 +310,94 @@ def test_bmm(): # build function f = tvm.build(sch.mod["main"], target="cuda") - print(f.imported_modules[0].get_source()) f(A_nd, B_nd, C_nd, indptr_n_nd, indptr_m_nd, indptr_k_nd, indptr_nm_nd, indptr_mk_nd, indptr_nk_nd) # assertion tvm.testing.assert_allclose(C_nd.numpy(), c_flatten, rtol=1e-5) +def test_square_sum(): + density = 0.0125 + M = N1 = N2 = 128 + A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr") + indptr_j = A_J.indptr + indices_j = A_J.indices + nnz_j = A_J.nnz + A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr") + indptr_k = A_K.indptr + indices_k = A_K.indices + nnz_k = A_K.nnz + data = A_K.data + + b_ij = np.asarray(A_K.sum(axis=1)).squeeze() + A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1)) + b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze() + b = np.zeros((M,)).astype("float32") + + v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = lowered_square_sum.params[-5:] + f = tvm.build(lowered_square_sum.specialize( + {v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="llvm") + + ctx = tvm.cpu(0) + A_data = tvm.nd.array(data.astype("float32"), device=ctx) + A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx) + A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx) + A_indptr_k = tvm.nd.array(indptr_k.astype("int32"), device=ctx) + A_indices_k = tvm.nd.array(indices_k.astype("int32"), device=ctx) + B_data = tvm.nd.array(b.astype("float32"), device=ctx) + f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k, A_indices_k) + + tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5) + + +def test_square_sum_two_K(): + sch = tir.Schedule(lowered_square_sum_two_K, debug_mask="all") + i, = sch.get_loops(sch.get_block("square_sum_2")) + sch.bind(i, "threadIdx.x") + + density = 0.0125 + M = N1 = N2 = 128 + A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr") + indptr_j = A_J.indptr + indices_j = A_J.indices + nnz_j = A_J.nnz + A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr") + indptr_k = A_K.indptr + indices_k = A_K.indices + nnz_k = A_K.nnz + data = A_K.data + + b_ij = np.asarray(A_K.sum(axis=1)).squeeze() + A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1)) + b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze() + b = np.zeros((M,)).astype("float32") + + v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = sch.mod["main"].params[-5:] + f = tvm.build(sch.mod["main"].specialize( + {v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="cuda") + + ctx = tvm.device("cuda") + A_data = tvm.nd.array(data.astype("float32"), device=ctx) + A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx) + A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx) + A_indptr_k0 = tvm.nd.array(indptr_k.astype("int32"), device=ctx) + A_indices_k0 = tvm.nd.array(indices_k.astype("int32"), device=ctx) + A_indptr_k1 = tvm.nd.array(indptr_k.astype("int32"), device=ctx) + A_indices_k1 = tvm.nd.array(indices_k.astype("int32"), device=ctx) + B_data = tvm.nd.array(b.astype("float32"), device=ctx) + f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k0, A_indices_k0, A_indptr_k1, A_indices_k1) + + tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": test_csrmm() + test_csr_reduce() + test_csr_element_wise() test_bsrmm() test_ellmm() test_sddmm() + test_sddmm_fuse() test_bmm() + test_square_sum() + test_square_sum_two_K() diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py index c436d07d0fa5..e33735cd2dda 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -16,501 +16,27 @@ # under the License. import tvm import tvm.testing -import tvm.tir as tir -import scipy.sparse as sp -import numpy as np import pytest -from tvm.script import tir as T - - -@T.prim_func -def csrmm( - a: T.handle, - b: T.handle, - c: T.handle, - indptr: T.handle, - indices: T.handle, - n: T.int32, - m: T.int32, - k: T.int32, - nnz: T.int32, -) -> None: - I = T.dense_fixed(n) - J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") - K = T.dense_fixed(k) - A = T.match_sparse_buffer(a, (I, J), "float32") - B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") - C = T.match_sparse_buffer(c, (I, K), "float32") - with T.iter([I, K, J], "SSR", "csrmm") as [vi, vk, vj]: - with T.init(): - C[vi, vk] = 0.0 - C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] - - -@T.prim_func -def csrmm_dense_iter( - a: T.handle, - b: T.handle, - c: T.handle, - indptr: T.handle, - indices: T.handle, - n: T.int32, - m: T.int32, - k: T.int32, - nnz: T.int32, -) -> None: - I = T.dense_fixed(n) - J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") - K = T.dense_fixed(k) - A = T.match_sparse_buffer(a, (I, J), "float32") - B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") - C = T.match_sparse_buffer(c, (I, K), "float32") - with T.iter([I, T.dense(J), K], "SRS", "csrmm") as [vi, vj, vk]: - with T.init(): - C[vi, vk] = 0.0 - C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] - - -@T.prim_func -def segment_reduce( - a: T.handle, - b: T.handle, - indptr: T.handle, - n: T.int32, - nnz: T.int32, -) -> None: - I = T.dense_fixed(n) - J = T.dense_variable(I, (100, nnz), indptr, "int32") - A = T.match_sparse_buffer(a, (I, J), "float32") - B = T.match_sparse_buffer(b, (I,), "float32") - with T.iter([I, J], "SR", "segment_reduce") as [vi, vj]: - with T.init(): - B[vi] = 0. - B[vi] = B[vi] + A[vi, vj] - - -@T.prim_func -def lowered_csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, m: T.int32, k: T.int32, nnz: T.int32) -> None: - A_data = T.match_buffer(a, [nnz], dtype="float32") - B_data = T.match_buffer(b, [m * k], dtype="float32") - C_data = T.match_buffer(c, [n * k], dtype="float32") - J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") - J_indices = T.match_buffer(indices, [nnz], dtype="int32") - for v_vi, v_vk in T.grid(n, k): - with T.block("csrmm_outer"): - vi, vk = T.axis.remap("SS", [v_vi, v_vk]) - T.reads([J_indptr[0: n + 1], J_indices[0: nnz], - A_data[0: nnz], B_data[0: m * k], C_data[0: n * k]]) - T.writes([C_data[0: n * k]]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): - with T.block("csrmm"): - vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) - T.reads([J_indptr[0: n + 1], J_indices[0: nnz], - A_data[0: nnz], B_data[0: m * k], C_data[0: n * k]]) - T.writes([C_data[0: n * k]]) - T.block_attr({"sparse": True}) - with T.init(): - C_data[vi * k + vk] = T.float32(0) - C_data[vi * k + vk] = C_data[vi * k + vk] + A_data[J_indptr[vi] + vj] * \ - B_data[J_indices[J_indptr[vi] + vj] * k + vk] - - -@T.prim_func -def csr_reduce( - a: T.handle, - b: T.handle, - indptr: T.handle, - indices: T.handle, - n: T.int32, - m: T.int32, - nnz: T.int32, -) -> None: - I = T.dense_fixed(n) - J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") - A = T.match_sparse_buffer(a, (I, J), "float32") - B = T.match_sparse_buffer(b, (I,), "float32") - with T.iter([I, J], "SR", "csr_reduce") as [vi, vj]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vj] - - -@T.prim_func -def lowered_csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, m: T.int32, nnz: T.int32) -> None: - A_data = T.match_buffer(a, [nnz], dtype="float32") - B_data = T.match_buffer(b, [n], dtype="float32") - J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") - J_indices = T.match_buffer(indices, [nnz], dtype="int32") - for v_vi in T.serial(0, n): - with T.block("csr_reduce_outer"): - vi = T.axis.spatial(n, v_vi) - T.reads([J_indptr[0: n + 1], J_indices[0: nnz], A_data[0: nnz], B_data[0: n]]) - T.writes([B_data[0: n]]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): - with T.block("csr_reduce"): - vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) - T.reads([J_indptr[0: n + 1], J_indices[0: nnz], A_data[0: nnz], B_data[0: n]]) - T.writes([B_data[0: n]]) - T.block_attr({"sparse": True}) - with T.init(): - B_data[vi] = T.float32(0) - B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] - - -@T.prim_func -def bsrmm( - a: T.handle, - b: T.handle, - c: T.handle, - indptr: T.handle, - indices: T.handle, - nb: T.int32, - mb: T.int32, - nnzb: T.int32, - blk: T.int32, - feat_size: T.int32, -) -> None: - I = T.dense_fixed(nb) - J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") - BI = T.dense_fixed(blk) - BJ = T.dense_fixed(blk) - F = T.dense_fixed(feat_size) - A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") - B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") - C = T.match_sparse_buffer(c, (I, BI, F), "float32") - - with T.iter([I, BI, BJ, F, J], "SSRSR", "bsrmm") as [ - vi, - vbi, - vbj, - vf, - vj, - ]: - with T.init(): - C[vi, vbi, vf] = 0.0 - C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] - - -@T.prim_func -def lowered_bsrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, nb: T.int32, mb: T.int32, nnzb: T.int32, blk: T.int32, feat_size: T.int32) -> None: - A_data = T.match_buffer(a, [nnzb * blk * blk], dtype="float32") - B_data = T.match_buffer(b, [mb * blk * feat_size], dtype="float32") - C_data = T.match_buffer(c, [nb * blk * feat_size], dtype="float32") - J_indptr = T.match_buffer(indptr, [nb + 1], dtype="int32") - J_indices = T.match_buffer(indices, [nnzb], dtype="int32") - for v_vi, v_vbi, v_vbj, v_vf in T.grid(nb, blk, blk, feat_size): - with T.block("bsrmm_outer"): - vi, vbi, vbj, vf = T.axis.remap("SSRS", [v_vi, v_vbi, v_vbj, v_vf]) - T.reads([J_indptr[0: nb + 1], J_indices[0: nnzb], A_data[0: nnzb * blk * blk], - B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]]) - T.writes([C_data[0: nb * blk * feat_size]]) - T.block_attr({"sparse": True}) - with T.init(): - C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) - for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): - with T.block("bsrmm"): - vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) - T.reads([J_indptr[0: nb + 1], J_indices[0: nnzb], A_data[0: nnzb * blk * blk], - B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]]) - T.writes([C_data[0: nb * blk * feat_size]]) - T.block_attr({"sparse": True}) - C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[( - (J_indptr[vi] + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[J_indptr[vi] + vj] * blk + vbj) * feat_size + vf] - - -@T.prim_func -def ellpack_mm( - a: T.handle, - b: T.handle, - c: T.handle, - indices: T.handle, - nb: T.int32, - mb: T.int32, - feat_size: T.int32, - col: T.int32, - blk: T.int32, -) -> None: - I = T.dense_fixed(nb) - J = T.sparse_fixed(I, (mb, col), indices, "int32") - F = T.dense_fixed(feat_size) - BI = T.dense_fixed(blk) - BJ = T.dense_fixed(blk) - A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") - B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") - C = T.match_sparse_buffer(c, (I, BI, F), "float32") - - with T.iter([I, J, BI, BJ, F], "SRSRS", "ellmm") as [ - vi, - vj, - vbi, - vbj, - vf, - ]: - with T.init(): - C[vi, vbi, vf] = 0.0 - C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] - - -@T.prim_func -def lowered_ellpack_mm(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, nb: T.int32, mb: T.int32, feat_size: T.int32, col: T.int32, blk: T.int32) -> None: - A_data = T.match_buffer(a, [nb * col * blk * blk], dtype="float32") - B_data = T.match_buffer(b, [mb * blk * feat_size], dtype="float32") - C_data = T.match_buffer(c, [nb * blk * feat_size], dtype="float32") - J_indices = T.match_buffer(indices, [nb * col], dtype="int32") - # body - # with T.block("root") - for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size): - with T.block("ellmm"): - vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf]) - T.reads([J_indices[0 : nb * col], A_data[0 : nb * col * blk * blk], B_data[0 : mb * blk * feat_size], C_data[0 : nb * blk * feat_size]]) - T.writes([C_data[0 : nb * blk * feat_size]]) - T.block_attr({"sparse":True}) - with T.init(): - C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) - C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[((vi * col + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] - - -@T.prim_func -def csr_element_wise( - a: T.handle, - b: T.handle, - indptr: T.handle, - indices: T.handle, - m: T.int32, - n: T.int32, - nnz: T.int32, -) -> None: - I = T.dense_fixed(m) - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - A = T.match_sparse_buffer(a, (I, J), "float32") - B = T.match_sparse_buffer(b, (I, J), "float32") - - with T.iter([I, J], "SS", "csr_element_wise") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.5 - - -@T.prim_func -def lowered_csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, nnz: T.int32) -> None: - A_data = T.match_buffer(a, [nnz], dtype="float32") - B_data = T.match_buffer(b, [nnz], dtype="float32") - J_indptr = T.match_buffer(indptr, [m + 1], dtype="int32") - J_indices = T.match_buffer(indices, [nnz], dtype="int32") - for v_vi in T.serial(0, m): - with T.block("csr_element_wise_outer"): - vi = T.axis.spatial(m, v_vi) - T.reads([J_indptr[0: m + 1], J_indices[0: nnz], A_data[0: nnz]]) - T.writes([B_data[0: nnz]]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): - with T.block("csr_element_wise"): - vj = T.axis.spatial(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) - T.reads([J_indptr[0: m + 1], J_indices[0: nnz], A_data[0: nnz]]) - T.writes([B_data[0: nnz]]) - T.block_attr({"sparse": True}) - B_data[J_indptr[vi] + vj] = A_data[J_indptr[vi] + vj] * T.float32(2.5) - -@T.prim_func -def bmm( - x: T.handle, - y: T.handle, - z: T.handle, - indptr_i: T.handle, - indptr_j: T.handle, - indptr_k: T.handle, - indptr_ij: T.handle, - indptr_jk: T.handle, - indptr_ik: T.handle, - batch_size: T.int32, - nnz_i: T.int32, - nnz_j: T.int32, - nnz_k: T.int32, - nnz_ij: T.int32, - nnz_jk: T.int32, - nnz_ik: T.int32 -) -> None: - B = T.dense_fixed(batch_size) - I = T.dense_variable(B, (32768, nnz_i), indptr_i, "int32") - J = T.dense_variable(B, (32768, nnz_j), indptr_j, "int32") - K = T.dense_variable(B, (32768, nnz_k), indptr_k, "int32") - IJ = T.attach_axis(I, J, nnz_ij, indptr_ij, "int32") - JK = T.attach_axis(J, K, nnz_jk, indptr_jk, "int32") - IK = T.attach_axis(I, K, nnz_ik, indptr_ik, "int32") - X = T.match_sparse_buffer(x, (B, I, IJ), "float32") - Y = T.match_sparse_buffer(y, (B, J, JK), "float32") - Z = T.match_sparse_buffer(z, (B, I, IK), "float32") - with T.iter([B, I, J, K], "SSRS", "bmm") as [vb, vi, vj, vk]: - with T.init(): - Z[vb, vi, vk] = 0. - Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vk] * Y[vb, vk, vj] - - -@T.prim_func -def sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: - I = T.dense_fixed(m) - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - K = T.dense_fixed(k) - A = T.match_sparse_buffer(a, (I, K), "float32") - B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") - C = T.match_sparse_buffer(c, (I, J), "float32") - - with T.iter([I, J, K], "SSR", "sddmm") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0. - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] - - -@T.prim_func -def fused_sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: - I = T.dense_fixed(m) - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - K = T.dense_fixed(k) - A = T.match_sparse_buffer(a, (I, K), "float32") - B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") - C = T.match_sparse_buffer(c, (I, J), "float32") - - with T.iter([T.fuse(I, J), K], "SSR", "sddmm") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0. - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] - - -@T.prim_func -def square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32): - I = T.dense_fixed(M) - J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32") - K = T.sparse_variable(J, (N2, nnz_k), (indptr_k, indices_k), "int32") - A = T.match_sparse_buffer(a, (I, J, K), "float32") - B = T.match_sparse_buffer(b, (I,), "float32") - - with T.iter([I, J, K], "SRR", "square_sum") as [vi, vj, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vj, vk] - - -@T.prim_func -def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None: - A_data = T.match_buffer(a, [nnz_k], dtype="float32") - B_data = T.match_buffer(b, [M], dtype="float32") - J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32") - J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") - K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") - K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32") - - for v_vi in T.serial(0, M): - with T.block("square_sum_2"): - vi = T.axis.spatial(M, v_vi) - T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K_indptr[0 : nnz_j + 1], K_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]]) - T.writes([B_data[0 : M]]) - T.block_attr({"sparse":True}) - for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): - with T.block("square_sum_1"): - vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) - T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K_indptr[0 : nnz_j + 1], K_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]]) - T.writes([B_data[0 : M]]) - T.block_attr({"sparse":True}) - with T.init(): - B_data[vi] = T.float32(0) - for v_vk in T.serial(0, K_indptr[J_indptr[v_vi] + v_vj + 1] - K_indptr[J_indptr[v_vi] + v_vj]): - with T.block("square_sum"): - vk = T.axis.reduce(K_indptr[J_indptr[v_vi] + v_vj + 1] - K_indptr[J_indptr[v_vi] + v_vj], v_vk) - T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K_indptr[0 : nnz_j + 1], K_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]]) - T.writes([B_data[0 : M]]) - T.block_attr({"sparse":True}) - B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk] - - -@T.prim_func -def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32): - # Used only for testing `GetIndicesRange()`. - # Currently it is ensured that `indptr_k0` is the same as `indptr_k1`, and `indices_k0` is the - # same as `indices_k1`. - I = T.dense_fixed(M) - J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32") - K0 = T.sparse_variable(J, (N2, nnz_k), (indptr_k0, indices_k0), "int32") - K1 = T.sparse_variable(J, (N2, nnz_k), (indptr_k1, indices_k1), "int32") - A = T.match_sparse_buffer(a, (I, J, K0), "float32") - B = T.match_sparse_buffer(b, (I,), "float32") - - with T.iter([I, J, K1], "SRR", "square_sum") as [vi, vj, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vj, vk] - - -@T.prim_func -def lowered_square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None: - A_data = T.match_buffer(a, [nnz_k], dtype="float32") - B_data = T.match_buffer(b, [M], dtype="float32") - J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32") - J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") - K0_indptr = T.match_buffer(indptr_k0, [nnz_j + 1], dtype="int32") - K0_indices = T.match_buffer(indices_k0, [nnz_k], dtype="int32") - K1_indptr = T.match_buffer(indptr_k1, [nnz_j + 1], dtype="int32") - K1_indices = T.match_buffer(indices_k1, [nnz_k], dtype="int32") - - for v_vi in T.serial(0, M): - with T.block("square_sum_2"): - vi = T.axis.spatial(M, v_vi) - T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]]) - T.writes([B_data[0 : M]]) - T.block_attr({"sparse":True}) - for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): - with T.block("square_sum_1"): - vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) - T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]]) - T.writes([B_data[0 : M]]) - T.block_attr({"sparse":True}) - with T.init(): - B_data[vi] = T.float32(0) - for v_vk in T.serial(0, K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj]): - with T.block("square_sum"): - vk = T.axis.reduce(K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj], v_vk) - T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]]) - T.writes([B_data[0 : M]]) - T.block_attr({"sparse":True}) - B_data[vi] = B_data[vi] + A_data[T.tvm_lower_bound(K0_indices.data, K1_indices[K1_indptr[J_indptr[vi] + vj] + vk], K0_indptr[J_indptr[vi] + vj], K0_indptr[J_indptr[vi] + vj + 1], dtype="int32")] +from lowered_tir import * +from sparse_tir_scripts import * def test_csrmm(): mod = tvm.IRModule.from_expr(csrmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) - print(mod["main"].script()) tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) - A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr") - x = np.random.rand(512, 128).astype("float32") - y_ground_truth = A * x - y = np.zeros((512, 128)).astype("float32") - - n, m, k, nnz = mod["main"].params[-4:] - f = tvm.build(mod["main"].specialize({n: 512, m: 512, k: 128, nnz: A.nnz}), target="llvm") - - ctx = tvm.cpu(0) - A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) - A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) - A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) - X_nd = tvm.nd.array(x.reshape(-1), device=ctx) - Y_nd = tvm.nd.array(y.reshape(-1), device=ctx) - f(A_data, X_nd, Y_nd, A_indptr, A_indices) - tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) - -@pytest.mark.skip(reason="Under implementation") def test_csrmm_dense_iter(): mod = tvm.IRModule.from_expr(csrmm_dense_iter) mod = tvm.tir.transform.LowerSparseTIR()(mod) - # tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) - # Todo + tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm_dense_iter, True) -@pytest.mark.skip(reason="Under implementation") def test_segment_reduce(): mod = tvm.IRModule.from_expr(segment_reduce) mod = tvm.tir.transform.LowerSparseTIR()(mod) - # Todo + tvm.ir.assert_structural_equal(mod["main"], lowered_segment_reduce, True) def test_csr_reduce(): @@ -518,108 +44,17 @@ def test_csr_reduce(): mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_csr_reduce, True) - A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") - b_ground_truth = np.array(np.sum(A, axis=1)) - b = np.zeros((128,)).astype("float32") - - n, m, nnz = csr_reduce.params[-3:] - f = tvm.build(mod["main"].specialize({n: 128, m: 128, nnz: A.nnz}), target="llvm") - - ctx = tvm.cpu(0) - A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) - A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) - A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) - B_nd = tvm.nd.array(b, device=ctx) - f(A_data, B_nd, A_indptr, A_indices) - tvm.testing.assert_allclose(b_ground_truth.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) - def test_bsrmm(): mod = tvm.IRModule.from_expr(bsrmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_bsrmm, True) - block_size = 16 - nb = 32 - mb = 32 - feat_size = 256 - n = nb * block_size - m = mb * block_size - - A_block = sp.random(mb, nb, dtype="float32", density=0.05, format="csr") - indptr = A_block.indptr - indices = A_block.indices - nnzb = A_block.nnz - data = np.random.rand(nnzb, block_size, block_size) - A = sp.bsr_matrix((data, indices, indptr), shape=(n, m)) - x = np.random.rand(m, feat_size).astype("float32") - y_ground_truth = A * x - y = np.zeros((n * feat_size,)).astype("float32") - - v_nb, v_mb, v_nnzb, v_blk, v_feat_size = bsrmm.params[-5:] - f = tvm.build( - mod["main"].specialize( - {v_nb: nb, v_mb: mb, v_nnzb: nnzb, v_blk: block_size, v_feat_size: feat_size} - ), - target="llvm", - ) - - ctx = tvm.cpu(0) - A_indptr = tvm.nd.array(indptr.astype("int32"), device=ctx) - A_indices = tvm.nd.array(indices.astype("int32"), device=ctx) - A_data = tvm.nd.array(data.reshape(-1).astype("float32"), device=ctx) - X_nd = tvm.nd.array(x.reshape(-1), device=ctx) - Y_nd = tvm.nd.array(y, device=ctx) - f(A_data, X_nd, Y_nd, A_indptr, A_indices) - tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) - def test_ellpack_mm(): - mod = tvm.IRModule.from_expr(ellpack_mm) + mod = tvm.IRModule.from_expr(ellmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_ellpack_mm, True) - - nnz_cols = 4 - nb = 64 - mb = 64 - feat_size = 1024 - nnz = nb * nnz_cols - block_size = 16 - n = nb * block_size - m = mb * block_size - - rng = np.random.default_rng() - indptr = np.arange(0, (nb + 1) * nnz_cols, nnz_cols) - indices = np.array([rng.choice(mb, size=nnz_cols, replace=False) for i in range(nb)]) - order = indices.argsort(axis=1) - indices = np.array([indices[i, order[i]] for i in range(0, nb)]).reshape(-1) - data = np.random.rand(nnz, block_size, block_size) - A = sp.bsr_matrix((data, indices, indptr), shape=(n, m)) - x = np.random.rand(m, feat_size).astype("float32") - y_ground_truth = A * x - y = np.zeros((n * feat_size,)).astype("float32") - - v_nb, v_mb, v_feat_size, v_col, v_blk = ellpack_mm.params[-5:] - f = tvm.build( - mod["main"].specialize( - { - v_nb: nb, - v_mb: mb, - v_feat_size: feat_size, - v_col: nnz_cols, - v_blk: block_size, - } - ), - target="llvm", - ) - - ctx = tvm.cpu(0) - A_indices = tvm.nd.array(indices.astype("int32"), device=ctx) - A_data = tvm.nd.array(data.reshape(-1).astype("float32"), device=ctx) - X_nd = tvm.nd.array(x.reshape(-1), device=ctx) - Y_nd = tvm.nd.array(y, device=ctx) - f(A_data, X_nd, Y_nd, A_indices) - tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) + tvm.ir.assert_structural_equal(mod["main"], lowered_ellmm, True) def test_csr_element_wise(): @@ -627,43 +62,26 @@ def test_csr_element_wise(): mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True) - A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") - b_ground_truth = A * 2.5 - b = np.zeros((A.nnz,)).astype("float32") - - m, n, nnz = csr_element_wise.params[-3:] - f = tvm.build(mod["main"].specialize({m: 128, n: 128, nnz: A.nnz}), target="llvm") - - ctx = tvm.cpu(0) - A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) - A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) - A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) - B_nd = tvm.nd.array(b, device=ctx) - f(A_data, B_nd, A_indptr, A_indices) - tvm.testing.assert_allclose(b_ground_truth.data.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) - @pytest.mark.skip(reason="Under implementation") def test_bmm(): mod = tvm.IRModule.from_expr(bmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) - print(mod["main"].script()) - # TODO + tvm.ir.assert_structural_equal(mod["main"], lowered_bmm) @pytest.mark.skip(reason="Under implementation") def test_sddmm(): mod = tvm.IRModule.from_expr(sddmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) - # TODO + tvm.ir.assert_structural_equal(mod["main"], lowered_sddmm) @pytest.mark.skip(reason="Under implementation") def test_fused_sddmm(): mod = tvm.IRModule.from_expr(fused_sddmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) - print(mod["main"].script()) - # TODO + tvm.ir.assert_structural_equal(mod["main"], lowered_sddmm_fuse) def test_square_sum(): @@ -671,80 +89,12 @@ def test_square_sum(): mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum, True) - density = 0.0125 - M = N1 = N2 = 128 - A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr") - indptr_j = A_J.indptr - indices_j = A_J.indices - nnz_j = A_J.nnz - A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr") - indptr_k = A_K.indptr - indices_k = A_K.indices - nnz_k = A_K.nnz - data = A_K.data - - b_ij = np.asarray(A_K.sum(axis=1)).squeeze() - A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1)) - b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze() - b = np.zeros((M,)).astype("float32") - - v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = square_sum.params[-5:] - f = tvm.build(mod["main"].specialize({v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="llvm") - - ctx = tvm.cpu(0) - A_data = tvm.nd.array(data.astype("float32"), device=ctx) - A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx) - A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx) - A_indptr_k = tvm.nd.array(indptr_k.astype("int32"), device=ctx) - A_indices_k = tvm.nd.array(indices_k.astype("int32"), device=ctx) - B_data = tvm.nd.array(b.astype("float32"), device=ctx) - f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k, A_indices_k) - - tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5) - def test_square_sum_two_K(): mod = tvm.IRModule.from_expr(square_sum_two_K) mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True) - sch = tir.Schedule(mod, debug_mask="all") - i, = sch.get_loops(sch.get_block("square_sum0")) - sch.bind(i, "threadIdx.x") - - density = 0.0125 - M = N1 = N2 = 128 - A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr") - indptr_j = A_J.indptr - indices_j = A_J.indices - nnz_j = A_J.nnz - A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr") - indptr_k = A_K.indptr - indices_k = A_K.indices - nnz_k = A_K.nnz - data = A_K.data - - b_ij = np.asarray(A_K.sum(axis=1)).squeeze() - A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1)) - b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze() - b = np.zeros((M,)).astype("float32") - - v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = square_sum_two_K.params[-5:] - f = tvm.build(sch.mod["main"].specialize({v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="cuda") - - ctx = tvm.device("cuda") - A_data = tvm.nd.array(data.astype("float32"), device=ctx) - A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx) - A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx) - A_indptr_k0 = tvm.nd.array(indptr_k.astype("int32"), device=ctx) - A_indices_k0 = tvm.nd.array(indices_k.astype("int32"), device=ctx) - A_indptr_k1 = tvm.nd.array(indptr_k.astype("int32"), device=ctx) - A_indices_k1 = tvm.nd.array(indices_k.astype("int32"), device=ctx) - B_data = tvm.nd.array(b.astype("float32"), device=ctx) - f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k0, A_indices_k0, A_indptr_k1, A_indices_k1) - - tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5) - if __name__ == "__main__": test_csrmm() diff --git a/tests/python/sparsetir/test_tir_sparse_tensorize.py b/tests/python/sparsetir/test_tir_sparse_tensorize.py new file mode 100644 index 000000000000..e69de29bb2d1