diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py
index 9c6bcc2c2bd9..de2a1c1b49ab 100644
--- a/tests/python/sparsetir/test_tir_sparse_lower.py
+++ b/tests/python/sparsetir/test_tir_sparse_lower.py
@@ -345,6 +345,52 @@ def bmm(
         Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vk] * Y[vb, vk, vj]
 
 
+@T.prim_func
+def square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32):
+    I = T.dense_fixed(M)
+    J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32")
+    K = T.sparse_variable(J, (N2, nnz_k), (indptr_k, indices_k), "int32")
+    A = T.match_sparse_buffer(a, (I, J, K), "float32")
+    B = T.match_sparse_buffer(b, (I,), "float32")
+
+    with T.iter([I, J, K], "SRR", "square_sum") as [vi, vj, vk]:
+        with T.init():
+            B[vi] = 0.0
+        B[vi] = B[vi] + A[vi, vj, vk]
+
+
+@T.prim_func
+def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None:
+    A_data = T.match_buffer(a, [nnz_k], dtype="float32")
+    B_data = T.match_buffer(b, [M], dtype="float32")
+    J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32")
+    J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32")
+    K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32")
+    K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32")
+
+    for v_vi in T.serial(0, M):
+        with T.block("square_sum_2"):
+            vi = T.axis.spatial(M, v_vi)
+            T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K_indptr[0 : nnz_j + 1], K_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
+            T.writes([B_data[0 : M]])
+            T.block_attr({"sparse":True})
+            for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
+                with T.block("square_sum_1"):
+                    vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj)
+                    T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K_indptr[0 : nnz_j + 1], K_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
+                    T.writes([B_data[0 : M]])
+                    T.block_attr({"sparse":True})
+                    with T.init():
+                        B_data[vi] = T.float32(0)
+                    for v_vk in T.serial(0, K_indptr[J_indptr[v_vi] + v_vj + 1] - K_indptr[J_indptr[v_vi] + v_vj]):
+                        with T.block("square_sum"):
+                            vk = T.axis.reduce(K_indptr[J_indptr[v_vi] + v_vj + 1] - K_indptr[J_indptr[v_vi] + v_vj], v_vk)
+                            T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K_indptr[0 : nnz_j + 1], K_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
+                            T.writes([B_data[0 : M]])
+                            T.block_attr({"sparse":True})
+                            B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk]
+
+
 def test_csrmm():
     mod = tvm.IRModule.from_expr(csrmm)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
@@ -372,11 +418,13 @@ def test_csrmm_dense_iter():
     mod = tvm.IRModule.from_expr(csrmm_dense_iter)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
     # tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)
+    # TODO(review): no assertion is active in this test; presumably the structural-equality check above should be re-enabled - confirm.
 
 
 def test_segment_reduce():
     mod = tvm.IRModule.from_expr(segment_reduce)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
+    # TODO(review): only the lowering pass runs and nothing is asserted; presumably an expected-output check is still to be added - confirm.
 
 
 def test_csr_reduce():
@@ -512,7 +560,44 @@ def test_csr_element_wise():
 def test_bmm():
     mod = tvm.IRModule.from_expr(bmm)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
-    print(mod['main'].script())
+    # TODO(review): only the lowering pass runs and nothing is asserted; presumably an expected-output check is still to be added - confirm.
+
+
+def test_square_sum():
+    mod = tvm.IRModule.from_expr(square_sum)
+    mod = tvm.tir.transform.LowerSparseTIR()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum, True)
+
+    density = 0.0125
+    M = N1 = N2 = 128
+    A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr")
+    indptr_j = A_J.indptr
+    indices_j = A_J.indices
+    nnz_j = A_J.nnz
+    A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr")
+    indptr_k = A_K.indptr
+    indices_k = A_K.indices
+    nnz_k = A_K.nnz
+    data = A_K.data
+
+    b_ij = np.asarray(A_K.sum(axis=1)).squeeze()
+    A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1))
+    b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze()
+    b = np.zeros((M,)).astype("float32")
+
+    v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = square_sum.params[-5:]
+    f = tvm.build(mod["main"].specialize({v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="llvm")
+
+    ctx = tvm.cpu(0)
+    A_data = tvm.nd.array(data.astype("float32"), device=ctx)
+    A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx)
+    A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx)
+    A_indptr_k = tvm.nd.array(indptr_k.astype("int32"), device=ctx)
+    A_indices_k = tvm.nd.array(indices_k.astype("int32"), device=ctx)
+    B_data = tvm.nd.array(b.astype("float32"), device=ctx)
+    f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k, A_indices_k)
+
+    tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5)
 
 
 if __name__ == "__main__":
@@ -524,3 +609,4 @@ def test_bmm():
     test_ellpack_mm()
     test_csr_element_wise()
     test_bmm()
+    test_square_sum()