Commit

tensoriz fixed
masahi committed May 17, 2022
1 parent 68039b0 commit c3cb170
Showing 1 changed file with 81 additions and 83 deletions.
164 changes: 81 additions & 83 deletions tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -127,17 +127,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
T.reads(
- C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
- A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
- B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
+ C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
+ A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2],
+ B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 2],
)
- T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
- C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
- i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
+ T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
+ C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
+ i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
] + T.cast(
- A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "float32"
+ A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2],
+ "float32",
) * T.cast(
- B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "float32"
+ B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 8 % 2],
+ "float32",
)
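The old expressions used j // 8 and i // 8; the fix rewrites them as j % 16 // 8 and i % 16 // 8 (equal for the 16-extent axes here, but presumably needed so the pattern matches the index map used by transform_layout further down). This is the 16x16 -> ldmatrix 32x8 warp layout that the rest of the diff calls shared_16x16_to_ldmatrix_32x8_layout. A minimal standalone sketch of that mapping, reconstructed from the new index expressions for illustration (not copied from the file):

def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    # (i, j) indexes a 16x16 fragment; the result is (thread_id, local_id)
    # in the 32-thread x 8-register layout used by ldmatrix / mma.sync.
    thread_id = i % 8 * 4 + j % 8 // 2
    local_id = j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2
    return thread_id, local_id

For i, j in [0, 16) this gives thread_id in [0, 32) and local_id in [0, 8), matching the C_warp[0:32, 0:8] buffer written by mma_fill_desc below.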


@@ -242,11 +244,19 @@ def mma_fill_desc(a: T.handle) -> None:
T.writes(C_warp[0:32, 0:8])
for i0, i1 in T.grid(32, 8):
with T.block("C_warp"):
- i = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
- j = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
+ i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
+ j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
T.reads()
- T.writes(C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
- C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = T.float32(0)
+ T.writes(
+ C_warp[
+ i_init % 8 * 4 + j_init % 8 // 2,
+ j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
+ ]
+ )
+ C_warp[
+ i_init % 8 * 4 + j_init % 8 // 2,
+ j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
+ ] = T.float32(0)
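The fill intrinsic writes the whole 32x8 warp buffer through the same index mapping, so that mapping must hit every (thread, register) slot exactly once. A quick standalone bijection check, assuming the shared_16x16_to_ldmatrix_32x8_layout sketch above (illustration only, not part of the test file):

def check_layout_is_bijective():
    seen = set()
    for i in range(16):
        for j in range(16):
            tid, lid = shared_16x16_to_ldmatrix_32x8_layout(i, j)
            assert 0 <= tid < 32 and 0 <= lid < 8
            seen.add((tid, lid))
    # 16 * 16 = 256 inputs land on 256 distinct slots of the 32x8 buffer.
    assert len(seen) == 32 * 8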


@T.prim_func
@@ -304,8 +314,8 @@ def schedule(sch: tir.Schedule):
num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
else:
i_factors = [1, 16, 4, 2, 2]
- j_factors = [1, 64, 1, 8, 1]
- k_factors = [128, 4, 1]
+ j_factors = [1, 32, 1, 8, 1]
+ k_factors = [64, 4, 1]
num_ty = i_factors[2] * j_factors[2]

i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
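The factor-list change is the core of the schedule fix. Taking the "4k" in the file name to mean a 4096x4096x4096 GEMM with a 16-wide warp fragment per axis (both inferred from the test name and the 16-wide splits below, not stated in this hunk), each split list has to multiply to 4096 / 16 = 256, whereas the old j and k factors multiplied to 512. A small sanity check under those assumptions:

import math

i_factors = [1, 16, 4, 2, 2]
j_factors = [1, 32, 1, 8, 1]  # previously [1, 64, 1, 8, 1]
k_factors = [64, 4, 1]        # previously [128, 4, 1]

for name, factors in (("i", i_factors), ("j", j_factors), ("k", k_factors)):
    # 256 fragment-level iterations per axis x 16 elements per fragment = 4096
    assert math.prod(factors) * 16 == 4096, (name, factors)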
@@ -368,7 +378,7 @@ def fetch_to_shared(block, idx, ndim):

ii, jj = sch.get_loops(C_warp)[-2:]
io, ii = sch.split(ii, factors=[None, 16])
- jo, ji = sch.split(jj, factors=[None, 8])
+ jo, ji = sch.split(jj, factors=[None, 16])
sch.reorder(io, jo, ii, ji)

block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
@@ -394,18 +404,10 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
loop_b = tile_wmma_fragment(B_warp, 16)

sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
- sch.transform_layout(
- B_warp,
- 0,
- "write",
- index_map=shared_16x16_to_ldmatrix_32x8_layout
- )
- sch.transform_layout(
- C_warp,
- 0,
- "read",
- index_map=shared_16x16_to_ldmatrix_32x8_layout
- )
+ sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+ sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)

+ # return

if use_ldmatrix:
sch.tensorize(loop_a, "mma.ldmatrix_a")
Expand All @@ -425,69 +427,65 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
fused_1 = sch.fuse(warp_loop2, f_0)
sch.bind(fused_1, "threadIdx.x")

- # mma_loop = sch.get_loops(block_inner)[-3]
- # sch.tensorize(mma_loop, "mma_sync")
-
- # block_init_c = sch.get_block("C_init")
- # init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
- # f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
- # f_2, f_3 = sch.split(init_loop2, factors=[None, 2])
- # sch.reorder(f_1, f_2, f_0, f_3)
- # fused_1 = sch.fuse(f_1, f_2)
- # fused_2 = sch.fuse(f_0, f_3)
- # # sch.bind(fused_1, "threadIdx.x")
- # sch.tensorize(fused_1, "mma_fill")
-
- # warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
- # f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
- # f_2, f_3 = sch.split(warp_loop2, factors=[None, 2])
- # sch.reorder(f_1, f_2, f_0, f_3)
- # fused_1 = sch.fuse(f_1, f_2)
- # fused_2 = sch.fuse(f_0, f_3)
-
- # # print(sch.mod.script())
- # # return
-
- # sch.tensorize(fused_1, "mma_store")
+ mma_loop = sch.get_loops(block_inner)[-3]
+ sch.tensorize(mma_loop, "mma_sync")
+
+ block_init_c = sch.get_block("C_init")
+ init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
+ f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
+ f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
+ sch.reorder(f_1, f_2, f_0, f_3)
+ fused_1 = sch.fuse(f_1, f_2)
+ fused_2 = sch.fuse(f_0, f_3)
+ sch.tensorize(fused_1, "mma_fill")
+
+ warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
+ f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
+ outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
+ sch.reorder(outer, f_1, f_2, f_0, f_3)
+ fused_1 = sch.fuse(f_1, f_2)
+ fused_2 = sch.fuse(f_0, f_3)
+ sch.tensorize(outer, "mma_store")
+ # print(sch.mod.script())
+ # return


ir_module = tvm.IRModule({"main": workload})
sch = tvm.tir.Schedule(ir_module)
schedule(sch)
print(sch.mod.script())

- # if tune:
- # with tempfile.TemporaryDirectory() as work_dir:
- # sch = ms.tune_tir(
- # mod=workload,
- # target=tvm.target.Target("nvidia/geforce-rtx-3070"),
- # config=ms.TuneConfig(
- # strategy="evolutionary",
- # num_trials_per_iter=32,
- # max_trials_per_task=128,
- # max_trials_global=128,
- # ),
- # work_dir=work_dir,
- # space=ms.space_generator.ScheduleFn(schedule),
- # )
- # if sch is None:
- # print("No valid schedule found!")
- # else:
- # print(sch.mod.script())
- # print(sch.trace)
- # else:
- # print(sch.mod.script())
- # target = "cuda"
- # f = tvm.build(sch.mod["main"], target=target, name="dense")
-
- # dev = tvm.device("cuda", 0)
- # a_np = np.random.uniform(size=(N, K)).astype("float16")
- # b_np = np.random.uniform(size=(K, M)).astype("float16")
- # c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
- # a = tvm.nd.array(a_np, dev)
- # b = tvm.nd.array(b_np, dev)
- # c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
- # f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+ if tune:
+ with tempfile.TemporaryDirectory() as work_dir:
+ sch = ms.tune_tir(
+ mod=workload,
+ target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+ config=ms.TuneConfig(
+ strategy="evolutionary",
+ num_trials_per_iter=32,
+ max_trials_per_task=128,
+ max_trials_global=128,
+ ),
+ work_dir=work_dir,
+ space=ms.space_generator.ScheduleFn(schedule),
+ )
+ if sch is None:
+ print("No valid schedule found!")
+ else:
+ print(sch.mod.script())
+ print(sch.trace)
+ else:
+ target = "cuda"
+ f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+ dev = tvm.device("cuda", 0)
+ a_np = np.random.uniform(size=(N, K)).astype("float16")
+ b_np = np.random.uniform(size=(K, M)).astype("float16")
+ c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+ a = tvm.nd.array(a_np, dev)
+ b = tvm.nd.array(b_np, dev)
+ c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+ f = tvm.build(sch.mod["main"], target="cuda", name="dense")

# print(f.imported_modules[0].get_source())
# f(a, b, c)
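The numerical comparison against c_np is cut off by the truncated diff. One typical way such a test would finish (an assumption about intent, not the file's actual tail) is to run the built kernel and check the GPU result against the NumPy reference computed above:

f(a, b, c)  # hypothetical: the file keeps this call commented out in this hunk
np.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)  # tolerance is illustrative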
