Axis Dependency Tree aware code-gen and bmm example #28

Merged: 12 commits, merged on Nov 26, 2021
Changes from 4 commits
2 changes: 1 addition & 1 deletion include/tvm/tir/transform.h
@@ -483,7 +483,7 @@ TVM_DLL Pass ConvertForLoopsToSerial();
* \brief Lower SparseTIR to TIR.
* \return The pass.
*/
TVM_DLL Pass LowerSparseTIR();
TVM_DLL Pass LowerSparseTIR(AxisTree t);

} // namespace transform
} // namespace tir
10 changes: 8 additions & 2 deletions python/tvm/tir/transform/transform.py
@@ -19,6 +19,7 @@
from typing import Optional
from . import _ffi_api
from . import function_pass as _fpass
from ..sparse import AxisTree


def Apply(ftransform):
@@ -751,12 +752,17 @@ def ConvertForLoopsToSerial():
return _ffi_api.ConvertForLoopsToSerial() # type: ignore


def LowerSparseTIR():
def LowerSparseTIR(t: AxisTree):
"""Lower SparseTIR to TIR

Parameters
----------
t : AxisTree
The axis dependency tree.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerSparseTIR() # type: ignore
return _ffi_api.LowerSparseTIR(t) # type: ignore
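
For reference, a usage sketch of the updated pass. Judging from the tests later in this diff, the AxisTree constructor takes a dict mapping each axis name to its parent axis, with None marking a root axis; the sketch assumes those semantics.

    from tvm.tir.sparse import AxisTree
    import tvm

    t = AxisTree({"J": "I", "I": None})  # J is nested under I (CSR-style); I is a root axis
    mod = tvm.tir.transform.LowerSparseTIR(t)(mod)  # lower SparseTIR to TIR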
6 changes: 3 additions & 3 deletions src/tir/transforms/lower_sparse_tir.cc
@@ -415,7 +415,7 @@ Stmt WrapWithRootBlock(Stmt body) {
* \param f The Sparse-TIR primitive function to lower.
* \return lowered primitive function in TIR.
*/
PrimFunc LowerSparseTIR(PrimFunc f) {
PrimFunc LowerSparseTIR(AxisTree t, PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
@@ -439,9 +439,9 @@ namespace transform {
/*!
 * \brief The lowering pass from SparseTIR to TIR.
*/
Pass LowerSparseTIR() {
Pass LowerSparseTIR(AxisTree t) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LowerSparseTIR(std::move(f));
return LowerSparseTIR(std::move(t), std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseTIR", {});
}
126 changes: 113 additions & 13 deletions tests/python/sparsetir/test_tir_sparse_correctness.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from ctypes import c_float
import tvm
import tvm.testing
from tvm.runtime.ndarray import device
import tvm.tir as tir
import scipy.sparse as sp
@@ -113,7 +115,40 @@ def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
with T.init():
C_data[vij] = 0.
C_data[vij] = C_data[vij] + \
A[T.lower_bound(C_indptr.data, vij, 0, M + 1) * K + vk] * B[C_indices[vij] * K + vk]
A[(T.upper_bound(C_indptr.data, vij, 0, M + 1) - 1) * K + vk] * B[C_indices[vij] * K + vk]
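
The fix above (lower_bound replaced by upper_bound minus one) recovers the row that owns nonzero vij from the CSR indptr array. A minimal sketch of why the bound matters, using a hypothetical helper that is not part of this diff:

    from bisect import bisect_left, bisect_right

    def row_of_nnz(indptr, e):
        # the row owning nonzero e is the last row whose start offset is <= e,
        # i.e. upper_bound(indptr, e) - 1
        return bisect_right(indptr, e) - 1

    indptr = [0, 2, 2, 5]                # three rows with 2, 0, and 3 nonzeros
    assert row_of_nnz(indptr, 2) == 2    # nonzero 2 opens row 2
    assert bisect_left(indptr, 2) == 1   # a plain lower_bound would land on the empty row 1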


@T.prim_func
def bmm_tir(a: T.handle, b: T.handle, c: T.handle,
indptr_i: T.handle, indptr_j: T.handle, indptr_k: T.handle,
indptr_ij: T.handle, indptr_jk: T.handle, indptr_ik: T.handle,
BATCH: T.int32,
NNZ_IJ: T.int32, NNZ_JK: T.int32, NNZ_IK: T.int32) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, (NNZ_IJ,), "float32")
B = T.match_buffer(b, (NNZ_JK,), "float32")
C = T.match_buffer(c, (NNZ_IK,), "float32")
indptr_I = T.match_buffer(indptr_i, (BATCH + 1,), "int32")
indptr_J = T.match_buffer(indptr_j, (BATCH + 1,), "int32")
indptr_K = T.match_buffer(indptr_k, (BATCH + 1,), "int32")
indptr_IJ = T.match_buffer(indptr_ij, (BATCH + 1,), "int32")
indptr_JK = T.match_buffer(indptr_jk, (BATCH + 1,), "int32")
indptr_IK = T.match_buffer(indptr_ik, (BATCH + 1,), "int32")
for b in T.grid(BATCH):
with T.block("bmm_outer"):
T.block_attr({"sparse": True})
vb = T.axis.S(BATCH, b)
with T.init():
Review comment (Owner Author): @MasterJH5574 Note that without these two lines we cannot identify bmm_outer as a reduction block. Any ideas on what we should do about these? (See the sketch after this function.)

T.evaluate(1)
for i, j, k in T.grid(indptr_I[vb + 1] - indptr_I[vb], indptr_J[vb + 1] - indptr_J[vb], indptr_K[vb + 1] - indptr_K[vb]):
with T.block("bmm_inner"):
T.block_attr({"sparse": True})
vi, vj, vk = T.axis.remap("SRS", [i, j, k])
with T.init():
C[indptr_IK[vb] + vi * (indptr_K[vb + 1] - indptr_K[vb]) + vk] = 0.
C[indptr_IK[vb] + vi * (indptr_K[vb + 1] - indptr_K[vb]) + vk] = C[indptr_IK[vb] + vi * (indptr_K[vb + 1] - indptr_K[vb]) + vk] +\
A[indptr_IJ[vb] + vi * (indptr_J[vb + 1] - indptr_J[vb]) + vj] * \
B[indptr_JK[vb] + vj * (indptr_K[vb + 1] - indptr_K[vb]) + vk]
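
On the review comment above: TIR classifies a block as a reduction block only when it carries an init statement, which is presumably why bmm_outer needs the placeholder T.evaluate(1) under T.init(). A minimal standalone illustration of that rule (assumed behavior, not part of this diff):

    @T.prim_func
    def row_sum(a: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128, 32), "float32")
        C = T.match_buffer(c, (128,), "float32")
        for i, k in T.grid(128, 32):
            with T.block("row_sum"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():  # dropping this init keeps "row_sum" from being a reduction block
                    C[vi] = 0.
                C[vi] = C[vi] + A[vi, vk]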


def test_csrmm():
@@ -151,7 +186,7 @@ def test_csrmm():
f(A_data, X_nd, Y_nd, A_indptr, A_indices)

# assertion
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())
tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5)


def test_bsrmm():
@@ -199,7 +234,7 @@ def test_bsrmm():
f(A_data, X_nd, Y_nd, A_indptr, A_indices)

# assertion
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())
tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5)


def test_ellmm():
@@ -240,7 +275,7 @@ def test_ellmm():
f(A_data, X_nd, Y_nd, A_indices)

# assertion
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())
tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5)


def test_sddmm():
@@ -269,8 +304,8 @@ def test_sddmm():
ij, k = sch.get_loops(blk)
sch.bind(ij, "blockIdx.x")
sch.bind(k, "threadIdx.x")
sch.decompose_reduction(blk, k)

print(len(np.unique(indptr)))
# convert numpy tensor to tvm ndarray
C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
C_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0))
@@ -280,21 +315,86 @@

# build function
f = tvm.build(sch.mod['main'], target="cuda")
# print(f.imported_modules[0].get_source())
f(X_nd, Y_nd, C_data, C_indptr, C_indices)

# assertion
np.allclose(z_ground_truth, C_data.numpy())
tvm.testing.assert_allclose(z_ground_truth, C_data.numpy(), rtol=1e-5)


def test_bmm():
# TODO(zihao)
pass
# generate random input
batch_size = 32
n_arr = np.random.randint(128, 1024, size=(batch_size,)).astype("int32")
m_arr = np.random.randint(128, 1024, size=(batch_size,)).astype("int32")
k_arr = np.random.randint(128, 1024, size=(batch_size,)).astype("int32")
nm_arr = n_arr * m_arr
mk_arr = m_arr * k_arr
nk_arr = n_arr * k_arr
indptr_n = np.concatenate(([0], n_arr)).cumsum()
indptr_m = np.concatenate(([0], m_arr)).cumsum()
indptr_k = np.concatenate(([0], k_arr)).cumsum()
indptr_nm = np.concatenate(([0], nm_arr)).cumsum()
indptr_mk = np.concatenate(([0], mk_arr)).cumsum()
indptr_nk = np.concatenate(([0], nk_arr)).cumsum()
nnz_ij = indptr_nm[-1]
nnz_jk = indptr_mk[-1]
nnz_ik = indptr_nk[-1]
As = [
np.random.rand(n, m).astype("float32") for n, m in zip(n_arr, m_arr)
]
Bs = [
np.random.rand(m, k).astype("float32") for m, k in zip(m_arr, k_arr)
]
Cs = [
np.matmul(A, B) for A, B in zip(As, Bs)
]
A_flatten = np.concatenate([A.flatten() for A in As], 0)
B_flatten = np.concatenate([B.flatten() for B in Bs], 0)
c_flatten = np.concatenate([C.flatten() for C in Cs], 0)
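
The flattened buffers above store each batch's matrix in a ragged layout: element (i, j) of batch b sits at offset indptr[b] + i * cols[b] + j, which is exactly the addressing bmm_tir performs. A small sketch of the correspondence, using a hypothetical helper that is not part of this diff:

    def ragged_at(flat, indptr, cols, b, i, j):
        # read element (i, j) of batch b from a flattened ragged buffer
        return flat[indptr[b] + i * cols[b] + j]

    # e.g. ragged_at(c_flatten, indptr_nk, k_arr, 1, 0, 2) == Cs[1][0, 2]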

# specialize function
_, _, _, _, _, _, _, _, _, BATCH, NNZ_IJ, NNZ_JK, NNZ_IK = bmm_tir.params
sch = tir.Schedule(
bmm_tir.specialize({
BATCH: batch_size, NNZ_IJ: nnz_ij, NNZ_JK: nnz_jk, NNZ_IK: nnz_ik
})
)
bmm_outer = sch.get_block("bmm_outer")
b, = sch.get_loops(bmm_outer)
bmm_inner = sch.get_block("bmm_inner")
i, j, k = sch.get_loops(bmm_inner)
sch.reorder(i, k, j)
io, ii = sch.split(i, [None, 32])
ko, ki = sch.split(k, [None, 32])
sch.bind(b, "blockIdx.x")
sch.bind(ki, "threadIdx.x")
sch.bind(ii, "threadIdx.y")
sch.decompose_reduction(bmm_inner, j)

# convert numpy tensor to tvm ndarray
dev = tvm.cuda(0)
A_nd = tvm.nd.array(A_flatten, device=dev)
B_nd = tvm.nd.array(B_flatten, device=dev)
C_nd = tvm.nd.array(np.zeros_like(c_flatten), device=dev)
indptr_n_nd = tvm.nd.array(indptr_n.astype("int32"), device=dev)
indptr_m_nd = tvm.nd.array(indptr_m.astype("int32"), device=dev)
indptr_k_nd = tvm.nd.array(indptr_k.astype("int32"), device=dev)
indptr_nm_nd = tvm.nd.array(indptr_nm.astype("int32"), device=dev)
indptr_mk_nd = tvm.nd.array(indptr_mk.astype("int32"), device=dev)
indptr_nk_nd = tvm.nd.array(indptr_nk.astype("int32"), device=dev)

# build function
f = tvm.build(sch.mod["main"], target="cuda")
print(f.imported_modules[0].get_source())
f(A_nd, B_nd, C_nd, indptr_n_nd, indptr_m_nd, indptr_k_nd, indptr_nm_nd, indptr_mk_nd, indptr_nk_nd)

# assertion
tvm.testing.assert_allclose(C_nd.numpy(), c_flatten, rtol=1e-5)


if __name__ == "__main__":
test_csrmm()
test_bsrmm()
test_ellmm()
test_sddmm()
# test_csrmm()
# test_bsrmm()
# test_ellmm()
# test_sddmm()
test_bmm()
40 changes: 34 additions & 6 deletions tests/python/sparsetir/test_tir_sparse_lower.py
@@ -23,6 +23,7 @@
import scipy.sparse as sp
import numpy as np
from tvm.script import tir as T
from tvm.tir.sparse import AxisTree


@T.prim_func
@@ -372,7 +373,11 @@ def lowered_csr_element_wise(

def test_csrmm():
mod = tvm.IRModule.from_expr(csrmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
t = AxisTree({
"J": "I",
"J": None
})
mod = tvm.tir.transform.LowerSparseTIR(t)(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)

A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr")
@@ -395,7 +400,11 @@

def test_csr_reduce():
mod = tvm.IRModule.from_expr(csr_reduce)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
t = AxisTree({
"J": "I",
"J": None
})
mod = tvm.tir.transform.LowerSparseTIR(t)(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_csr_reduce, True)

A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr")
@@ -416,7 +425,14 @@

def test_bsrmm():
mod = tvm.IRModule.from_expr(bsrmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
t = AxisTree({
"J": "I",
"I": None,
"BJ": None,
"BI": None,
"F": None
})
mod = tvm.tir.transform.LowerSparseTIR(t)(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_bsrmm, True)

block_size = 16
@@ -456,7 +472,14 @@

def test_ellpack_mm():
mod = tvm.IRModule.from_expr(ellpack_mm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
t = AxisTree({
"J": "I",
"I": None,
"F": None,
"BI": None,
"BJ": None
})
mod = tvm.tir.transform.LowerSparseTIR(t)(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_ellpack_mm, True)

nnz_cols = 4
@@ -505,13 +528,18 @@

def test_batch_mm():
mod = tvm.IRModule.from_expr(batch_mm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
t = AxisTree({})
mod = tvm.tir.transform.LowerSparseTIR(t)(mod)
# print(mod["main"].script(tir_prefix="T"))


def test_csr_element_wise():
mod = tvm.IRModule.from_expr(csr_element_wise)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
t = AxisTree({
"J": "I",
"I": None
})
mod = tvm.tir.transform.LowerSparseTIR(t)(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True)

A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr")