diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 0cde7f246465..0703811ea79f 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -18,6 +18,8 @@ """Intrinsics for tensorization on NVIDIA GPU.""" from typing import Dict, Tuple +from typing_extensions import Literal + from tvm.script import tir as T from tvm.tir.function import PrimFunc @@ -815,54 +817,101 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", True), ) -WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a" +WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_shared" TensorIntrin.register( WMMA_LOAD_16x16x16_F16_A_INTRIN, *get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, False), ) -WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b" +WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN = "wmma_load_16x16x16_f16_a_shared_dyn" +TensorIntrin.register( + WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN, + *get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, False), +) + +WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b_shared" TensorIntrin.register( WMMA_LOAD_16x16x16_F16_B_INTRIN, *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, False), ) -WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans" +WMMA_LOAD_16x16x16_F16_B_DYN_INTRIN = "wmma_load_16x16x16_f16_b_shared_dyn" +TensorIntrin.register( + WMMA_LOAD_16x16x16_F16_B_DYN_INTRIN, + *get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, False), +) + +WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans_shared" TensorIntrin.register( WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN, *get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, True), ) -WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans" +WMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN = "wmma_load_16x16x16_f16_a_trans_shared_dyn" +TensorIntrin.register( + WMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN, + *get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, True), +) + +WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans_shared" TensorIntrin.register( WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN, *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True), ) -WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a" +WMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN = "wmma_load_16x16x16_f16_b_trans_shared_dyn" +TensorIntrin.register( + WMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN, + *get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, True), +) + +WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a_shared" TensorIntrin.register( WMMA_LOAD_16x16x16_S8_A_INTRIN, *get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, False), ) -WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b" +WMMA_LOAD_16x16x16_S8_A_DYN_INTRIN = "wmma_load_16x16x16_s8_a_shared_dyn" +TensorIntrin.register( + WMMA_LOAD_16x16x16_S8_A_DYN_INTRIN, + *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, False), +) + +WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b_shared" TensorIntrin.register( WMMA_LOAD_16x16x16_S8_B_INTRIN, *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, False), ) -WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans" +WMMA_LOAD_16x16x16_S8_B_DYN_INTRIN = "wmma_load_16x16x16_s8_b_shared_dyn" +TensorIntrin.register( + WMMA_LOAD_16x16x16_S8_B_DYN_INTRIN, + *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, False), +) + +WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans_shared" TensorIntrin.register( WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN, *get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, True), ) -WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans" +WMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN = "wmma_load_16x16x16_s8_a_trans_shared_dyn" +TensorIntrin.register( + WMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN, + *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, True), +) + +WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans_shared" TensorIntrin.register( WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN, *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, True), ) +WMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN = "wmma_load_16x16x16_s8_b_trans_shared_dyn" +TensorIntrin.register( + WMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN, + *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, True), +) WMMA_FILL_16x16x16_F32_INTRIN = "wmma_fill_16x16x16_f32" TensorIntrin.register(WMMA_FILL_16x16x16_F32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float32")) @@ -878,16 +927,34 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: WMMA_STORE_16x16x16_F32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "shared") ) +WMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN = "wmma_store_16x16x16_f32_shared_dyn" +TensorIntrin.register( + WMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN, + *get_wmma_store_intrin(16, 16, 16, "float32", "shared.dyn"), +) + WMMA_STORE_16x16x16_F16_SHARED_INTRIN = "wmma_store_16x16x16_f16_shared" TensorIntrin.register( WMMA_STORE_16x16x16_F16_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "shared") ) +WMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN = "wmma_store_16x16x16_f16_shared_dyn" +TensorIntrin.register( + WMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN, + *get_wmma_store_intrin(16, 16, 16, "float16", "shared.dyn"), +) + WMMA_STORE_16x16x16_S32_SHARED_INTRIN = "wmma_store_16x16x16_s32_shared" TensorIntrin.register( WMMA_STORE_16x16x16_S32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "shared") ) +WMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN = "wmma_store_16x16x16_s32_shared_dyn" +TensorIntrin.register( + WMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN, + *get_wmma_store_intrin(16, 16, 16, "int32", "shared.dyn"), +) + WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global" TensorIntrin.register( WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "global") @@ -905,14 +972,21 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: def get_wmma_intrin_group( - store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool + load_scope: Literal["shared", "shared.dyn"], + store_scope: Literal["global", "shared", "shared.dyn"], + in_dtype: str, + out_dtype: str, + trans_b: bool, ) -> Dict[str, str]: """Get a group of intrinsics for wmma tensor core with the given configurations Parameters ---------- - store_scope : str - Must be one of ["global", "shared"]. The memory scope of the result buffer. + load_scope : Literal["shared", "shared.dyn"] + The memory scope of the input buffer. + + store_scope : Literal["global", "shared", "shared.dyn"] + The memory scope of the result buffer. in_dtype : str The input data type. @@ -928,51 +1002,35 @@ def get_wmma_intrin_group( ret : Dict[str, str] A group of tensor intrinsics. """ - assert store_scope in ["global", "shared"] + assert load_scope in ["shared", "shared.dyn"] + assert store_scope in ["global", "shared", "shared.dyn"] assert in_dtype in ["float16", "int8"] assert out_dtype in ["float16", "float32", "int32"] - load_a_intrins = { - "float16": WMMA_LOAD_16x16x16_F16_A_INTRIN, - "int8": WMMA_LOAD_16x16x16_S8_A_INTRIN, - } - load_b_intrins = { - "float16": WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN - if trans_b - else WMMA_LOAD_16x16x16_F16_B_INTRIN, - "int8": WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN if trans_b else WMMA_LOAD_16x16x16_S8_B_INTRIN, - } - compute_intrins = { - "float16": WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN - if trans_b - else WMMA_SYNC_16x16x16_f16f16f16_INTRIN, - "float32": WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN - if trans_b - else WMMA_SYNC_16x16x16_f16f16f32_INTRIN, - "int32": WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN - if trans_b - else WMMA_SYNC_16x16x16_s8s8s32_INTRIN, - } - init_intrins = { - "float16": WMMA_FILL_16x16x16_F16_INTRIN, - "float32": WMMA_FILL_16x16x16_F32_INTRIN, - "int32": WMMA_FILL_16x16x16_S32_INTRIN, - } - store_intrins = { - "float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN - if store_scope == "shared" - else WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN, - "float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN - if store_scope == "shared" - else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, - "int32": WMMA_STORE_16x16x16_S32_SHARED_INTRIN - if store_scope == "shared" - else WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN, - } + shape = "16x16x16" + in_dtype = "f16" if in_dtype == "float16" else "s8" + out_dtype = "f16" if out_dtype == "float16" else "f32" if out_dtype == "float32" else "s32" + # convert "shared.dyn" to "shared_dyn" + load_scope = load_scope.replace(".", "_") + store_scope = store_scope.replace(".", "_") + trans_a = "" + trans_b = "_trans" if trans_b else "" + + # e.g. wmma_load_16x16x16_f16_a_shared + load_a_intrin = f"wmma_load_{shape}_{in_dtype}_a{trans_a}_{load_scope}" + # e.g. wmma_load_16x16x16_f16_b_trans_shared_dyn + load_b_intrin = f"wmma_load_{shape}_{in_dtype}_b{trans_b}_{load_scope}" + # e.g. wmma_sync_16x16x16_f16f16f32_trans + compute_intrin = f"wmma_sync_{shape}_{in_dtype}{in_dtype}{out_dtype}{trans_b}" + # e.g. wmma_fill_16x16x16_f16 + init_intrin = f"wmma_fill_{shape}_{out_dtype}" + # e.g. wmma_store_16x16x16_f16_shared_dyn + store_intrin = f"wmma_store_{shape}_{out_dtype}_{store_scope}" + return { - "init": init_intrins[out_dtype], - "load_a": load_a_intrins[in_dtype], - "load_b": load_b_intrins[in_dtype], - "compute": compute_intrins[out_dtype], - "store": store_intrins[out_dtype], + "init": init_intrin, + "load_a": load_a_intrin, + "load_b": load_b_intrin, + "compute": compute_intrin, + "store": store_intrin, } diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 938d39377f1c..49a7c9911c01 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -175,47 +175,47 @@ Array ScheduleRule::DefaultCUDATensorCore() { // Tensor Cores f32 += f16 * f16 { {"init", "wmma_fill_16x16x16_f32"}, - {"load_a", "wmma_load_16x16x16_f16_a"}, - {"load_b", "wmma_load_16x16x16_f16_b"}, + {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"}, + {"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"}, {"compute", "wmma_sync_16x16x16_f16f16f32"}, - {"store", "wmma_store_16x16x16_f32_shared"}, + {"store", "wmma_store_16x16x16_f32_shared_dyn"}, }, { {"init", "wmma_fill_16x16x16_f32"}, - {"load_a", "wmma_load_16x16x16_f16_a"}, - {"load_b", "wmma_load_16x16x16_f16_b_trans"}, + {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"}, + {"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"}, {"compute", "wmma_sync_16x16x16_f16f16f32_trans"}, - {"store", "wmma_store_16x16x16_f32_shared"}, + {"store", "wmma_store_16x16x16_f32_shared_dyn"}, }, // Tensor Cores f16 += f16 * f16 { {"init", "wmma_fill_16x16x16_f16"}, - {"load_a", "wmma_load_16x16x16_f16_a"}, - {"load_b", "wmma_load_16x16x16_f16_b"}, + {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"}, + {"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"}, {"compute", "wmma_sync_16x16x16_f16f16f16"}, - {"store", "wmma_store_16x16x16_f16_shared"}, + {"store", "wmma_store_16x16x16_f16_shared_dyn"}, }, { {"init", "wmma_fill_16x16x16_f16"}, - {"load_a", "wmma_load_16x16x16_f16_a"}, - {"load_b", "wmma_load_16x16x16_f16_b_trans"}, + {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"}, + {"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"}, {"compute", "wmma_sync_16x16x16_f16f16f16_trans"}, - {"store", "wmma_store_16x16x16_f16_shared"}, + {"store", "wmma_store_16x16x16_f16_shared_dyn"}, }, // Tensor Cores s32 += s8 * s8 { {"init", "wmma_fill_16x16x16_s32"}, - {"load_a", "wmma_load_16x16x16_s8_a"}, - {"load_b", "wmma_load_16x16x16_s8_b"}, + {"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"}, + {"load_b", "wmma_load_16x16x16_s8_b_shared_dyn"}, {"compute", "wmma_sync_16x16x16_s8s8s32"}, - {"store", "wmma_store_16x16x16_s32_shared"}, + {"store", "wmma_store_16x16x16_s32_shared_dyn"}, }, { {"init", "wmma_fill_16x16x16_s32"}, - {"load_a", "wmma_load_16x16x16_s8_a"}, - {"load_b", "wmma_load_16x16x16_s8_b_trans"}, + {"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"}, + {"load_b", "wmma_load_16x16x16_s8_b_trans_shared_dyn"}, {"compute", "wmma_sync_16x16x16_s8s8s32_trans"}, - {"store", "wmma_store_16x16x16_s32_shared"}, + {"store", "wmma_store_16x16x16_s32_shared_dyn"}, }, }; Array results{ @@ -229,11 +229,11 @@ Array ScheduleRule::DefaultCUDATensorCore() { /*reuse_read=*/ Map{{"req", String("must")}, {"levels", Array{4}}, // - {"scope", String("shared")}}, + {"scope", String("shared.dyn")}}, /*reuse_write=*/ Map{{"req", String("must")}, {"levels", Array{2}}, // - {"scope", String("shared")}}, + {"scope", String("shared.dyn")}}, /*use_software_pipeline=*/false) // }; Array append = ScheduleRule::DefaultCUDA(); diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index 73b2c990f08a..064769915955 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import pytest + import tvm import tvm.testing from tvm import meta_schedule as ms @@ -31,13 +34,15 @@ def multi_level_tiling_tensor_core( *, + read_reuse_scope="shared", write_reuse_scope="shared", in_dtype="float16", out_dtype="float32", trans_b=False, use_software_pipeline=False, ) -> ms.schedule_rule.ScheduleRule: - assert write_reuse_scope in ["shared", "global"] + assert read_reuse_scope in ["shared", "shared.dyn"] + assert write_reuse_scope in ["shared", "shared.dyn", "global"] if not isinstance(in_dtype, list): in_dtype = [in_dtype] if not isinstance(out_dtype, list): @@ -46,7 +51,9 @@ def multi_level_tiling_tensor_core( trans_b = [trans_b] return ms.schedule_rule.MultiLevelTilingTensorCore( intrin_groups=[ - get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b) + get_wmma_intrin_group( + read_reuse_scope, write_reuse_scope, _in_dtype, _out_dtype, _trans_b + ) for _in_dtype in in_dtype for _out_dtype in out_dtype for _trans_b in trans_b @@ -58,10 +65,10 @@ def multi_level_tiling_tensor_core( reuse_read=ms.schedule_rule.ReuseType( req="must", levels=[4], - scope="shared", + scope=read_reuse_scope, ), reuse_write=ms.schedule_rule.ReuseType( - req="must" if write_reuse_scope == "shared" else "no", + req="must" if write_reuse_scope.startswith("shared") else "no", levels=[2], scope=write_reuse_scope, ), @@ -69,15 +76,17 @@ def multi_level_tiling_tensor_core( ) -def test_matmul_relu(): +@pytest.mark.parametrize("shared_scope", ["shared", "shared.dyn"]) +def test_matmul_relu(shared_scope): + intrin_suffix = shared_scope.replace(".", "_") # fmt: off @T.prim_func def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope=shared_scope) C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): @@ -107,7 +116,7 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -120,7 +129,7 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -155,7 +164,7 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_reindex_shared_wmma.accumulator"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -196,7 +205,9 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f target=tvm.target.Target("cuda"), types=None, sch_rules=[ - multi_level_tiling_tensor_core(), + multi_level_tiling_tensor_core( + read_reuse_scope=shared_scope, write_reuse_scope=shared_scope + ), ] + get_rules(kind="cuda", types=ms.schedule_rule.AutoInline), ) @@ -249,7 +260,7 @@ def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax1_0) T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -262,7 +273,7 @@ def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -355,16 +366,18 @@ def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, ) -def test_conv2d(): +@pytest.mark.parametrize("shared_scope", ["shared", "shared.dyn"]) +def test_conv2d(shared_scope): + intrin_suffix = shared_scope.replace(".", "_") # fmt: off @T.prim_func def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, 3, 32, 32), "float16"], conv2d_nhwc: T.Buffer[(1, 16, 16, 32), "float32"]) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16") - conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope="shared") + conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope=shared_scope) conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator") - PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope="shared") - weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope="shared") + PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope=shared_scope) + weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope=shared_scope) PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a") weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b") for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): @@ -400,7 +413,7 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0) T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("PadInput_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -413,7 +426,7 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("weight_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -448,7 +461,7 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -492,7 +505,9 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, target=tvm.target.Target("cuda"), types=None, sch_rules=[ - multi_level_tiling_tensor_core(), + multi_level_tiling_tensor_core( + read_reuse_scope=shared_scope, write_reuse_scope=shared_scope + ), ], ) check_sketches( @@ -511,6 +526,8 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, types=None, sch_rules=[ multi_level_tiling_tensor_core( + read_reuse_scope=shared_scope, + write_reuse_scope=shared_scope, in_dtype="float16", out_dtype=["float16", "float32"], ), @@ -524,7 +541,9 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, ) -def test_matmul_relu_pipeline(): +@pytest.mark.parametrize("shared_scope", ["shared", "shared.dyn"]) +def test_matmul_relu_pipeline(shared_scope): + intrin_suffix = shared_scope.replace(".", "_") # fmt: off @T.prim_func def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None: @@ -533,10 +552,10 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, # body # with T.block("root") C = T.alloc_buffer([128, 128], dtype="float32") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope=shared_scope) C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): @@ -566,7 +585,7 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax1_0) T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -579,7 +598,7 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -614,7 +633,7 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_reindex_shared_wmma.accumulator"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -660,6 +679,8 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, types=None, sch_rules=[ multi_level_tiling_tensor_core( + read_reuse_scope=shared_scope, + write_reuse_scope=shared_scope, use_software_pipeline=True, ), ], @@ -713,7 +734,7 @@ def matmul_relu_global_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 1 v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax1_0) T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -726,7 +747,7 @@ def matmul_relu_global_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 1 v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -868,7 +889,7 @@ def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 1 v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -881,7 +902,7 @@ def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 1 v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -1008,7 +1029,7 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[ v1_o = T.axis.spatial(4, ax1_0_1) T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): with T.block("PadInput_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) @@ -1021,7 +1042,7 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[ v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0) T.reads(weight_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): with T.block("weight_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index aadc530a9ba8..c242f63b98ea 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -1743,7 +1743,7 @@ def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), " v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax1_0_1) T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): with T.block("pad_temp_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) @@ -1757,7 +1757,7 @@ def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), " v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax3_0) T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): with T.block("p1_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) @@ -2312,7 +2312,7 @@ def apply_trace(sch): sch.annotate( block_or_loop=b158, ann_key="meta_schedule.auto_tensorize", - ann_val="wmma_load_16x16x16_s8_a", + ann_val="wmma_load_16x16x16_s8_a_shared", ) b159 = sch.cache_read(block=b38, read_buffer_index=1, storage_scope="wmma.matrix_b") sch.compute_at(block=b159, loop=l80, preserve_unit_loops=True, index=-1) @@ -2355,7 +2355,7 @@ def apply_trace(sch): sch.annotate( block_or_loop=b192, ann_key="meta_schedule.auto_tensorize", - ann_val="wmma_load_16x16x16_s8_b_trans", + ann_val="wmma_load_16x16x16_s8_b_trans_shared", ) sch.compute_inline(block=b17) sch.compute_inline(block=b18) @@ -2490,10 +2490,10 @@ def apply_trace(sch): sch.tensorize(block_or_loop=b314, tensor_intrin="wmma_fill_16x16x16_s32") b315 = sch.get_block(name="pad_temp_reindex_shared_wmma.matrix_a_o", func_name="main") sch.unannotate(block_or_loop=b315, ann_key="meta_schedule.auto_tensorize") - sch.tensorize(block_or_loop=b315, tensor_intrin="wmma_load_16x16x16_s8_a") + sch.tensorize(block_or_loop=b315, tensor_intrin="wmma_load_16x16x16_s8_a_shared") b316 = sch.get_block(name="p1_reindex_shared_wmma.matrix_b_o", func_name="main") sch.unannotate(block_or_loop=b316, ann_key="meta_schedule.auto_tensorize") - sch.tensorize(block_or_loop=b316, tensor_intrin="wmma_load_16x16x16_s8_b_trans") + sch.tensorize(block_or_loop=b316, tensor_intrin="wmma_load_16x16x16_s8_b_trans_shared") b317 = sch.get_block(name="conv2d_nhwc_o_update", func_name="main") sch.unannotate(block_or_loop=b317, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b317, tensor_intrin="wmma_sync_16x16x16_s8s8s32_trans") @@ -3281,7 +3281,7 @@ def apply_trace(sch: Schedule) -> None: sch.annotate( block_or_loop=b152, ann_key="meta_schedule.auto_tensorize", - ann_val="wmma_load_16x16x16_s8_a", + ann_val="wmma_load_16x16x16_s8_a_shared", ) b153 = sch.cache_read(block=b32, read_buffer_index=1, storage_scope="wmma.matrix_b") sch.compute_at(block=b153, loop=l74, preserve_unit_loops=True, index=-1) @@ -3324,7 +3324,7 @@ def apply_trace(sch: Schedule) -> None: sch.annotate( block_or_loop=b186, ann_key="meta_schedule.auto_tensorize", - ann_val="wmma_load_16x16x16_s8_b_trans", + ann_val="wmma_load_16x16x16_s8_b_trans_shared", ) sch.compute_inline(block=b11) sch.compute_inline(block=b12) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index f9c5e22e97ce..bd46e10efaea 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -703,7 +703,7 @@ def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), " v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): with T.block("pad_temp_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) @@ -718,7 +718,7 @@ def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), " v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans"}) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): with T.block("p1_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1])