[MetaSchedule] Use shared.dyn for Tensor Core Schedule Rules #13891

Merged (1 commit, Feb 1, 2023)
166 changes: 112 additions & 54 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -18,6 +18,8 @@
"""Intrinsics for tensorization on NVIDIA GPU."""
from typing import Dict, Tuple

from typing_extensions import Literal

from tvm.script import tir as T
from tvm.tir.function import PrimFunc

@@ -815,54 +817,101 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
*get_wmma_sync_intrin(16, 16, 16, "int8", "int32", True),
)

WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a"
WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, False),
)

WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b"
WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN = "wmma_load_16x16x16_f16_a_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, False),
)

WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, False),
)

WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans"
WMMA_LOAD_16x16x16_F16_B_DYN_INTRIN = "wmma_load_16x16x16_f16_b_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, False),
)

WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, True),
)

WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans"
WMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN = "wmma_load_16x16x16_f16_a_trans_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, True),
)

WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True),
)

WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a"
WMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN = "wmma_load_16x16x16_f16_b_trans_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, True),
)

WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_A_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, False),
)

WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b"
WMMA_LOAD_16x16x16_S8_A_DYN_INTRIN = "wmma_load_16x16x16_s8_a_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_A_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, False),
)

WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_B_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, False),
)

WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans"
WMMA_LOAD_16x16x16_S8_B_DYN_INTRIN = "wmma_load_16x16x16_s8_b_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_B_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, False),
)

WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, True),
)

WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans"
WMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN = "wmma_load_16x16x16_s8_a_trans_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, True),
)

WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, True),
)

WMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN = "wmma_load_16x16x16_s8_b_trans_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, True),
)
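
With this change the static-shared load intrinsics gain an explicit _shared suffix and the new dynamic-shared variants end in _shared_dyn, so both flavors are addressable by name. A minimal lookup sketch, assuming a TVM build that contains these registrations:

from tvm.tir import TensorIntrin
from tvm.tir.tensor_intrin.cuda import (
    WMMA_LOAD_16x16x16_F16_A_INTRIN,      # "wmma_load_16x16x16_f16_a_shared"
    WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN,  # "wmma_load_16x16x16_f16_a_shared_dyn"
)

# Both the static-shared and the dynamic-shared load intrinsics are now
# registered under distinct names, so either can be retrieved for tensorization.
static_load = TensorIntrin.get(WMMA_LOAD_16x16x16_F16_A_INTRIN)
dyn_load = TensorIntrin.get(WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN)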

WMMA_FILL_16x16x16_F32_INTRIN = "wmma_fill_16x16x16_f32"
TensorIntrin.register(WMMA_FILL_16x16x16_F32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float32"))
@@ -878,16 +927,34 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
WMMA_STORE_16x16x16_F32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "shared")
)

WMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN = "wmma_store_16x16x16_f32_shared_dyn"
TensorIntrin.register(
WMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN,
*get_wmma_store_intrin(16, 16, 16, "float32", "shared.dyn"),
)

WMMA_STORE_16x16x16_F16_SHARED_INTRIN = "wmma_store_16x16x16_f16_shared"
TensorIntrin.register(
WMMA_STORE_16x16x16_F16_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "shared")
)

WMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN = "wmma_store_16x16x16_f16_shared_dyn"
TensorIntrin.register(
WMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN,
*get_wmma_store_intrin(16, 16, 16, "float16", "shared.dyn"),
)

WMMA_STORE_16x16x16_S32_SHARED_INTRIN = "wmma_store_16x16x16_s32_shared"
TensorIntrin.register(
WMMA_STORE_16x16x16_S32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "shared")
)

WMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN = "wmma_store_16x16x16_s32_shared_dyn"
TensorIntrin.register(
WMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN,
*get_wmma_store_intrin(16, 16, 16, "int32", "shared.dyn"),
)
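
Each get_wmma_store_intrin call above is unpacked with * into TensorIntrin.register, i.e. it yields a (description, implementation) pair of PrimFuncs. A short sketch of registering one more variant by hand; the registration name below is hypothetical:

from tvm.tir import TensorIntrin
from tvm.tir.tensor_intrin.cuda import get_wmma_store_intrin

# Reuse the generator for a hypothetical custom registration name.
desc, impl = get_wmma_store_intrin(16, 16, 16, "float32", "shared.dyn")
TensorIntrin.register("my_wmma_store_16x16x16_f32_shared_dyn", desc, impl)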

WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global"
TensorIntrin.register(
WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "global")
@@ -905,14 +972,21 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:


def get_wmma_intrin_group(
store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool
load_scope: Literal["shared", "shared.dyn"],
store_scope: Literal["global", "shared", "shared.dyn"],
in_dtype: str,
out_dtype: str,
trans_b: bool,
) -> Dict[str, str]:
"""Get a group of intrinsics for wmma tensor core with the given configurations

Parameters
----------
store_scope : str
Must be one of ["global", "shared"]. The memory scope of the result buffer.
load_scope : Literal["shared", "shared.dyn"]
The memory scope of the input buffer.

store_scope : Literal["global", "shared", "shared.dyn"]
The memory scope of the result buffer.

in_dtype : str
The input data type.
@@ -928,51 +1002,35 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
ret : Dict[str, str]
A group of tensor intrinsics.
"""
assert store_scope in ["global", "shared"]
assert load_scope in ["shared", "shared.dyn"]
assert store_scope in ["global", "shared", "shared.dyn"]
assert in_dtype in ["float16", "int8"]
assert out_dtype in ["float16", "float32", "int32"]

load_a_intrins = {
"float16": WMMA_LOAD_16x16x16_F16_A_INTRIN,
"int8": WMMA_LOAD_16x16x16_S8_A_INTRIN,
}
load_b_intrins = {
"float16": WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN
if trans_b
else WMMA_LOAD_16x16x16_F16_B_INTRIN,
"int8": WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN if trans_b else WMMA_LOAD_16x16x16_S8_B_INTRIN,
}
compute_intrins = {
"float16": WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
if trans_b
else WMMA_SYNC_16x16x16_f16f16f16_INTRIN,
"float32": WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN
if trans_b
else WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
"int32": WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN
if trans_b
else WMMA_SYNC_16x16x16_s8s8s32_INTRIN,
}
init_intrins = {
"float16": WMMA_FILL_16x16x16_F16_INTRIN,
"float32": WMMA_FILL_16x16x16_F32_INTRIN,
"int32": WMMA_FILL_16x16x16_S32_INTRIN,
}
store_intrins = {
"float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
if store_scope == "shared"
else WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN,
"float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
if store_scope == "shared"
else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
"int32": WMMA_STORE_16x16x16_S32_SHARED_INTRIN
if store_scope == "shared"
else WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN,
}
shape = "16x16x16"
in_dtype = "f16" if in_dtype == "float16" else "s8"
out_dtype = "f16" if out_dtype == "float16" else "f32" if out_dtype == "float32" else "s32"
# convert "shared.dyn" to "shared_dyn"
load_scope = load_scope.replace(".", "_")
store_scope = store_scope.replace(".", "_")
trans_a = ""
trans_b = "_trans" if trans_b else ""

# e.g. wmma_load_16x16x16_f16_a_shared
load_a_intrin = f"wmma_load_{shape}_{in_dtype}_a{trans_a}_{load_scope}"
# e.g. wmma_load_16x16x16_f16_b_trans_shared_dyn
load_b_intrin = f"wmma_load_{shape}_{in_dtype}_b{trans_b}_{load_scope}"
# e.g. wmma_sync_16x16x16_f16f16f32_trans
compute_intrin = f"wmma_sync_{shape}_{in_dtype}{in_dtype}{out_dtype}{trans_b}"
# e.g. wmma_fill_16x16x16_f16
init_intrin = f"wmma_fill_{shape}_{out_dtype}"
# e.g. wmma_store_16x16x16_f16_shared_dyn
store_intrin = f"wmma_store_{shape}_{out_dtype}_{store_scope}"

return {
"init": init_intrins[out_dtype],
"load_a": load_a_intrins[in_dtype],
"load_b": load_b_intrins[in_dtype],
"compute": compute_intrins[out_dtype],
"store": store_intrins[out_dtype],
"init": init_intrin,
"load_a": load_a_intrin,
"load_b": load_b_intrin,
"compute": compute_intrin,
"store": store_intrin,
}
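
Given the name construction above, a call that targets dynamic shared memory resolves to the _shared_dyn intrinsics registered earlier. A usage sketch, with the expected result derived from the f-strings in the function body:

from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group

group = get_wmma_intrin_group(
    load_scope="shared.dyn",
    store_scope="shared.dyn",
    in_dtype="float16",
    out_dtype="float32",
    trans_b=True,
)
# group == {
#     "init": "wmma_fill_16x16x16_f32",
#     "load_a": "wmma_load_16x16x16_f16_a_shared_dyn",
#     "load_b": "wmma_load_16x16x16_f16_b_trans_shared_dyn",
#     "compute": "wmma_sync_16x16x16_f16f16f32_trans",
#     "store": "wmma_store_16x16x16_f32_shared_dyn",
# }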
40 changes: 20 additions & 20 deletions src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -175,47 +175,47 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
// Tensor Cores f32 += f16 * f16
{
{"init", "wmma_fill_16x16x16_f32"},
{"load_a", "wmma_load_16x16x16_f16_a"},
{"load_b", "wmma_load_16x16x16_f16_b"},
{"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
{"compute", "wmma_sync_16x16x16_f16f16f32"},
{"store", "wmma_store_16x16x16_f32_shared"},
{"store", "wmma_store_16x16x16_f32_shared_dyn"},
},
{
{"init", "wmma_fill_16x16x16_f32"},
{"load_a", "wmma_load_16x16x16_f16_a"},
{"load_b", "wmma_load_16x16x16_f16_b_trans"},
{"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
{"compute", "wmma_sync_16x16x16_f16f16f32_trans"},
{"store", "wmma_store_16x16x16_f32_shared"},
{"store", "wmma_store_16x16x16_f32_shared_dyn"},
},
// Tensor Cores f16 += f16 * f16
{
{"init", "wmma_fill_16x16x16_f16"},
{"load_a", "wmma_load_16x16x16_f16_a"},
{"load_b", "wmma_load_16x16x16_f16_b"},
{"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
{"compute", "wmma_sync_16x16x16_f16f16f16"},
{"store", "wmma_store_16x16x16_f16_shared"},
{"store", "wmma_store_16x16x16_f16_shared_dyn"},
},
{
{"init", "wmma_fill_16x16x16_f16"},
{"load_a", "wmma_load_16x16x16_f16_a"},
{"load_b", "wmma_load_16x16x16_f16_b_trans"},
{"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
{"compute", "wmma_sync_16x16x16_f16f16f16_trans"},
{"store", "wmma_store_16x16x16_f16_shared"},
{"store", "wmma_store_16x16x16_f16_shared_dyn"},
},
// Tensor Cores s32 += s8 * s8
{
{"init", "wmma_fill_16x16x16_s32"},
{"load_a", "wmma_load_16x16x16_s8_a"},
{"load_b", "wmma_load_16x16x16_s8_b"},
{"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_s8_b_shared_dyn"},
{"compute", "wmma_sync_16x16x16_s8s8s32"},
{"store", "wmma_store_16x16x16_s32_shared"},
{"store", "wmma_store_16x16x16_s32_shared_dyn"},
},
{
{"init", "wmma_fill_16x16x16_s32"},
{"load_a", "wmma_load_16x16x16_s8_a"},
{"load_b", "wmma_load_16x16x16_s8_b_trans"},
{"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_s8_b_trans_shared_dyn"},
{"compute", "wmma_sync_16x16x16_s8s8s32_trans"},
{"store", "wmma_store_16x16x16_s32_shared"},
{"store", "wmma_store_16x16x16_s32_shared_dyn"},
},
};
Array<ScheduleRule> results{
@@ -229,11 +229,11 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
/*reuse_read=*/
Map<String, ObjectRef>{{"req", String("must")},
{"levels", Array<Integer>{4}}, //
{"scope", String("shared")}},
{"scope", String("shared.dyn")}},
/*reuse_write=*/
Map<String, ObjectRef>{{"req", String("must")},
{"levels", Array<Integer>{2}}, //
{"scope", String("shared")}},
{"scope", String("shared.dyn")}},
/*use_software_pipeline=*/false) //
};
Array<ScheduleRule> append = ScheduleRule::DefaultCUDA();
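
For reference, a hedged Python-side sketch of an equivalent rule using the new shared.dyn scopes. It assumes the tvm.meta_schedule.schedule_rule API (MultiLevelTilingTensorCore, ReuseType) mirrors the C++ constructor used here; the structure and tile_binds values are illustrative assumptions and are not part of this hunk, only the reuse scopes and intrinsic groups reflect the change above:

from tvm import meta_schedule as ms
from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group

rule = ms.schedule_rule.MultiLevelTilingTensorCore(
    intrin_groups=[
        # f32 += f16 * f16, plain and transposed-B variants
        get_wmma_intrin_group("shared.dyn", "shared.dyn", "float16", "float32", False),
        get_wmma_intrin_group("shared.dyn", "shared.dyn", "float16", "float32", True),
    ],
    structure="SSSRRSRS",                                    # assumed default
    tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],  # assumed default
    reuse_read=ms.schedule_rule.ReuseType(req="must", levels=[4], scope="shared.dyn"),
    reuse_write=ms.schedule_rule.ReuseType(req="must", levels=[2], scope="shared.dyn"),
    use_software_pipeline=False,
)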