diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune.py b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
index ce52f390c51d..e7f5454e9e59 100644
--- a/tests/python/unittest/test_mma_16x8x16_4k_tune.py
+++ b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -127,17 +127,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
             with T.block("C"):
                 i, j, k = T.axis.remap("SSR", [i, j, k])
                 T.reads(
-                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
+                    C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
+                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2],
+                    B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 2],
                 )
-                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
-                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
+                T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
+                C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
+                    i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
                 ] + T.cast(
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "float32"
+                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2],
+                    "float32",
                 ) * T.cast(
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "float32"
+                    B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 8 % 2],
+                    "float32",
                 )
@@ -242,11 +244,19 @@ def mma_fill_desc(a: T.handle) -> None:
         T.writes(C_warp[0:32, 0:8])
         for i0, i1 in T.grid(32, 8):
             with T.block("C_warp"):
-                i = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
-                j = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
+                i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
+                j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
                 T.reads()
-                T.writes(C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = T.float32(0)
+                T.writes(
+                    C_warp[
+                        i_init % 8 * 4 + j_init % 8 // 2,
+                        j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
+                    ]
+                )
+                C_warp[
+                    i_init % 8 * 4 + j_init % 8 // 2,
+                    j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
+                ] = T.float32(0)
 
 
 @T.prim_func
@@ -304,8 +314,8 @@ def schedule(sch: tir.Schedule):
         num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
     else:
         i_factors = [1, 16, 4, 2, 2]
-        j_factors = [1, 64, 1, 8, 1]
-        k_factors = [128, 4, 1]
+        j_factors = [1, 32, 1, 8, 1]
+        k_factors = [64, 4, 1]
         num_ty = i_factors[2] * j_factors[2]
 
     i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
@@ -368,7 +378,7 @@ def fetch_to_shared(block, idx, ndim):
 
     ii, jj = sch.get_loops(C_warp)[-2:]
     io, ii = sch.split(ii, factors=[None, 16])
-    jo, ji = sch.split(jj, factors=[None, 8])
+    jo, ji = sch.split(jj, factors=[None, 16])
     sch.reorder(io, jo, ii, ji)
 
     block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
@@ -394,18 +404,10 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    loop_b = tile_wmma_fragment(B_warp, 16)
 
     sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(
-        B_warp,
-        0,
-        "write",
-        index_map=shared_16x16_to_ldmatrix_32x8_layout
-    )
-    sch.transform_layout(
-        C_warp,
-        0,
-        "read",
-        index_map=shared_16x16_to_ldmatrix_32x8_layout
-    )
+    sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+    sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+
+    # return
 
     if use_ldmatrix:
         sch.tensorize(loop_a, "mma.ldmatrix_a")
@@ -425,30 +427,27 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
         fused_1 = sch.fuse(warp_loop2, f_0)
         sch.bind(fused_1, "threadIdx.x")
 
-    # mma_loop = sch.get_loops(block_inner)[-3]
-    # sch.tensorize(mma_loop, "mma_sync")
-
-    # block_init_c = sch.get_block("C_init")
-    # init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
-    # f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
-    # f_2, f_3 = sch.split(init_loop2, factors=[None, 2])
-    # sch.reorder(f_1, f_2, f_0, f_3)
-    # fused_1 = sch.fuse(f_1, f_2)
-    # fused_2 = sch.fuse(f_0, f_3)
-    # # sch.bind(fused_1, "threadIdx.x")
-    # sch.tensorize(fused_1, "mma_fill")
-
-    # warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
-    # f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
-    # f_2, f_3 = sch.split(warp_loop2, factors=[None, 2])
-    # sch.reorder(f_1, f_2, f_0, f_3)
-    # fused_1 = sch.fuse(f_1, f_2)
-    # fused_2 = sch.fuse(f_0, f_3)
-
-    # # print(sch.mod.script())
-    # # return
-
-    # sch.tensorize(fused_1, "mma_store")
+    mma_loop = sch.get_loops(block_inner)[-3]
+    sch.tensorize(mma_loop, "mma_sync")
+
+    block_init_c = sch.get_block("C_init")
+    init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
+    f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
+    f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
+    sch.reorder(f_1, f_2, f_0, f_3)
+    fused_1 = sch.fuse(f_1, f_2)
+    fused_2 = sch.fuse(f_0, f_3)
+    sch.tensorize(fused_1, "mma_fill")
+
+    warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
+    f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
+    outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
+    sch.reorder(outer, f_1, f_2, f_0, f_3)
+    fused_1 = sch.fuse(f_1, f_2)
+    fused_2 = sch.fuse(f_0, f_3)
+    sch.tensorize(outer, "mma_store")
+    # print(sch.mod.script())
+    # return
 
 
 ir_module = tvm.IRModule({"main": workload})
@@ -456,38 +455,37 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
 sch = tir.Schedule(ir_module)
 schedule(sch)
 print(sch.mod.script())
 
-# if tune:
-#     with tempfile.TemporaryDirectory() as work_dir:
-#         sch = ms.tune_tir(
-#             mod=workload,
-#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-#             config=ms.TuneConfig(
-#                 strategy="evolutionary",
-#                 num_trials_per_iter=32,
-#                 max_trials_per_task=128,
-#                 max_trials_global=128,
-#             ),
-#             work_dir=work_dir,
-#             space=ms.space_generator.ScheduleFn(schedule),
-#         )
-#         if sch is None:
-#             print("No valid schedule found!")
-#         else:
-#             print(sch.mod.script())
-#             print(sch.trace)
-#     else:
-#         print(sch.mod.script())
-# target = "cuda"
-# f = tvm.build(sch.mod["main"], target=target, name="dense")
-
-# dev = tvm.device("cuda", 0)
-# a_np = np.random.uniform(size=(N, K)).astype("float16")
-# b_np = np.random.uniform(size=(K, M)).astype("float16")
-# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-# a = tvm.nd.array(a_np, dev)
-# b = tvm.nd.array(b_np, dev)
-# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-# f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+if tune:
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch = ms.tune_tir(
+            mod=workload,
+            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+            config=ms.TuneConfig(
+                strategy="evolutionary",
+                num_trials_per_iter=32,
+                max_trials_per_task=128,
+                max_trials_global=128,
+            ),
+            work_dir=work_dir,
+            space=ms.space_generator.ScheduleFn(schedule),
+        )
+        if sch is None:
+            print("No valid schedule found!")
+        else:
+            print(sch.mod.script())
+            print(sch.trace)
+else:
+    target = "cuda"
+    f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+dev = tvm.device("cuda", 0)
+a_np = np.random.uniform(size=(N, K)).astype("float16")
+b_np = np.random.uniform(size=(K, M)).astype("float16")
+c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+a = tvm.nd.array(a_np, dev)
+b = tvm.nd.array(b_np, dev)
+c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+f = tvm.build(sch.mod["main"], target="cuda", name="dense")
 # print(f.imported_modules[0].get_source())
 # f(a, b, c)
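Note on the index-arithmetic hunks: the patch rewrites expressions such as `j // 8` into `j % 16 // 8`, which is identical for j < 16 (the `% 8 % 2` and `% 2` forms are likewise equal). As a standalone sanity check (not part of the patch; the mapping below is re-derived from the index expressions visible in the hunks rather than taken from the test file), the layout used by `shared_16x16_to_ldmatrix_32x8_layout` can be checked to be a bijection from a 16x16 tile onto the 32-thread x 8-element warp buffer, and the `mma_fill_desc` indexing to cover that buffer exactly once:

# Standalone check, assuming only the index expressions shown in the diff above.
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    # (thread id, element id) for a 16x16 tile, as in the T.reads/T.writes expressions.
    return i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2


# Every (i, j) of the 16x16 tile must land on a distinct slot of the 32x8 warp buffer.
mapped = {shared_16x16_to_ldmatrix_32x8_layout(i, j) for i in range(16) for j in range(16)}
assert mapped == {(t, e) for t in range(32) for e in range(8)}

# mma_fill_desc iterates i0 in [0, 32) and i1 in [0, 8); its derived (i_init, j_init)
# indices should also touch every (thread, element) slot exactly once.
filled = set()
for i0 in range(32):
    for i1 in range(8):
        i_init = i1 // 4 * 8 + i0 // 4
        j_init = (i0 % 4) * 4 + i1 % 4
        filled.add(shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init))
assert filled == mapped
print("layout map and fill indexing cover all 32 x 8 slots exactly once")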
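The patch re-enables the build and the numpy input setup at the bottom of the file but still leaves the final kernel invocation commented out (`# f(a, b, c)`). A possible follow-up sketch, not part of the patch, assuming the `f`, `a`, `b`, `c`, `c_np`, and `dev` defined above and standard TVM runtime APIs:

# Hedged sketch (not in the diff): run the compiled kernel and compare against numpy.
f(a, b, c)
np.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)

# Rough timing via TVM's built-in evaluator.
evaluator = f.time_evaluator(f.entry_name, dev, number=100)
print("dense: %.3f ms" % (evaluator(a, b, c).mean * 1e3))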