diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 68c3c3b1a8d8..d85db83f27a9 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -78,6 +78,9 @@ class TuneContextNode : public runtime::Object { v->Visit("measure_candidates", &measure_candidates); } + /*! \brief Initialize members that needs initialization with tune context. */ + void Initialize(); + static constexpr const char* _type_key = "meta_schedule.TuneContext"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); }; diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 07d7c038a54f..5b796d975b30 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -112,9 +112,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode { Array result{sch}; // Enumerate the schedule rules first because you can // always concat multiple schedule rules as one + Array all_blocks = BlockCollector::Collect(sch); for (ScheduleRule sch_rule : sch_rules_) { for (const tir::Schedule& sch : result) { - stack.emplace_back(sch, BlockCollector::Collect(sch)); + stack.emplace_back(sch, all_blocks); } result.clear(); @@ -137,6 +138,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { for (const tir::Schedule& sch : applied) { stack.emplace_back(sch, blocks); } + } else { + stack.emplace_back(sch, blocks); } } } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 84c39b5ab1fc..10ca48e74b7b 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -92,28 +92,8 @@ Array SendToRunner(const Runner& runner, // return results; } -void TaskSchedulerNode::InitializeTask(int task_id) { - TuneContext task = this->tasks[task_id]; - // Derive the values. - IRModule mod = task->mod.value(); - SpaceGenerator space = task->space_generator.value(); - SearchStrategy strategy = task->search_strategy.value(); - // Initialize Modules. - space->InitializeWithTuneContext(task); - strategy->InitializeWithTuneContext(task); - // Initialize the rules. - if (task->sch_rules.defined()) - for (const ScheduleRule& sch_rule : task->sch_rules.value()) { - sch_rule->InitializeWithTuneContext(task); - } - if (task->mutators.defined()) - for (const Mutator& mutator : task->mutators.value()) { - mutator->InitializeWithTuneContext(task); - } - if (task->postprocs.defined()) - for (const Postproc& postproc : task->postprocs.value()) { - postproc->InitializeWithTuneContext(task); - } +void TaskSchedulerNode::InitializeTask(int task_id) { // + this->tasks[task_id]->Initialize(); } void TaskSchedulerNode::Tune() { diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 885a4cf2ce19..f5a4fe3b1baf 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -54,6 +54,30 @@ TuneContext::TuneContext(Optional mod, data_ = std::move(n); } +void TuneContextNode::Initialize() { + if (this->space_generator.defined()) { + this->space_generator.value()->InitializeWithTuneContext(GetRef(this)); + } + if (this->search_strategy.defined()) { + this->search_strategy.value()->InitializeWithTuneContext(GetRef(this)); + } + if (this->sch_rules.defined()) { + for (const ScheduleRule& sch_rule : sch_rules.value()) { + sch_rule->InitializeWithTuneContext(GetRef(this)); + } + } + if (this->postprocs.defined()) { + for (const Postproc& postproc : postprocs.value()) { + postproc->InitializeWithTuneContext(GetRef(this)); + } + } + if (this->mutators.defined()) { + for (const Mutator& mutator : mutators.value()) { + mutator->InitializeWithTuneContext(GetRef(this)); + } + } +} + TVM_REGISTER_NODE_TYPE(TuneContextNode); TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 95b5ed002b27..1ce99378f728 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -298,16 +298,14 @@ def correct_trace(a, b, c, d): 'b1 = sch.get_block(name="B", func_name="main")', 'b2 = sch.get_block(name="C", func_name="main")', "sch.compute_inline(block=b1)", - 'b3 = sch.get_block(name="A", func_name="main")', - 'b4 = sch.get_block(name="C", func_name="main")', - "l5, l6 = sch.get_loops(block=b4)", - "l7, l8 = sch.split(loop=l5, factors=" + str(a) + ")", - "l9, l10 = sch.split(loop=l6, factors=" + str(b) + ")", - "sch.reorder(l7, l9, l8, l10)", - "l11, l12 = sch.get_loops(block=b3)", - "l13, l14 = sch.split(loop=l11, factors=" + str(c) + ")", - "l15, l16 = sch.split(loop=l12, factors=" + str(d) + ")", - "sch.reorder(l13, l15, l14, l16)", + "l3, l4 = sch.get_loops(block=b2)", + "l5, l6 = sch.split(loop=l3, factors=" + str(a) + ")", + "l7, l8 = sch.split(loop=l4, factors=" + str(b) + ")", + "sch.reorder(l5, l7, l6, l8)", + "l9, l10 = sch.get_loops(block=b0)", + "l11, l12 = sch.split(loop=l9, factors=" + str(c) + ")", + "l13, l14 = sch.split(loop=l10, factors=" + str(d) + ")", + "sch.reorder(l11, l13, l12, l14)", ] ) diff --git a/tests/python/unittest/test_meta_schedule_sketch_cpu.py b/tests/python/unittest/test_meta_schedule_sketch_cpu.py index cdcc518b69c2..e2e0e2c42905 100644 --- a/tests/python/unittest/test_meta_schedule_sketch_cpu.py +++ b/tests/python/unittest/test_meta_schedule_sketch_cpu.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring - from typing import List +import sys + +import pytest from tvm.meta_schedule.testing import te_workload from tvm.meta_schedule.testing.space_generation import check_trace, create_context @@ -231,15 +233,15 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable # pylint: disable=line-too-long expected = [ [ - 'b0 = sch.get_block(name="bias_add", func_name="main")', - 'b1 = sch.get_block(name="bn_mul", func_name="main")', - 'b2 = sch.get_block(name="bn_add", func_name="main")', + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.get_block(name="bias_add", func_name="main")', + 'b2 = sch.get_block(name="bn_mul", func_name="main")', + 'b3 = sch.get_block(name="bn_add", func_name="main")', + "sch.compute_inline(block=b3)", "sch.compute_inline(block=b2)", "sch.compute_inline(block=b1)", - "sch.compute_inline(block=b0)", - 'b3 = sch.get_block(name="compute", func_name="main")', - "b4, = sch.get_consumers(block=b3)", - "l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b3)", + "b4, = sch.get_consumers(block=b0)", + "l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b0)", "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", "l16, l17, l18, l19 = sch.split(loop=l5, factors=[v12, v13, v14, v15])", "v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", @@ -258,15 +260,15 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable "sch.reverse_compute_at(block=b4, loop=l41, preserve_unit_loops=1)", ], [ - 'b0 = sch.get_block(name="bias_add", func_name="main")', - 'b1 = sch.get_block(name="bn_mul", func_name="main")', - 'b2 = sch.get_block(name="bn_add", func_name="main")', + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.get_block(name="bias_add", func_name="main")', + 'b2 = sch.get_block(name="bn_mul", func_name="main")', + 'b3 = sch.get_block(name="bn_add", func_name="main")', + "sch.compute_inline(block=b3)", "sch.compute_inline(block=b2)", "sch.compute_inline(block=b1)", - "sch.compute_inline(block=b0)", - 'b3 = sch.get_block(name="compute", func_name="main")', - "b4, = sch.get_consumers(block=b3)", - "l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b3)", + "b4, = sch.get_consumers(block=b0)", + "l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b0)", "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", "l16, l17, l18, l19 = sch.split(loop=l5, factors=[v12, v13, v14, v15])", "v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", @@ -285,14 +287,14 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable "sch.reverse_compute_at(block=b4, loop=l40, preserve_unit_loops=1)", ], [ - 'b0 = sch.get_block(name="bias_add", func_name="main")', - 'b1 = sch.get_block(name="bn_mul", func_name="main")', - 'b2 = sch.get_block(name="bn_add", func_name="main")', + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.get_block(name="bias_add", func_name="main")', + 'b2 = sch.get_block(name="bn_mul", func_name="main")', + 'b3 = sch.get_block(name="bn_add", func_name="main")', + "sch.compute_inline(block=b3)", "sch.compute_inline(block=b2)", "sch.compute_inline(block=b1)", - "sch.compute_inline(block=b0)", - 'b3 = sch.get_block(name="compute", func_name="main")', - "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b3)", + "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)", "v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", "l15, l16, l17, l18 = sch.split(loop=l4, factors=[v11, v12, v13, v14])", "v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", @@ -352,8 +354,4 @@ def test_meta_schedule_sketch_cpu_max_pool2d_nchw(): if __name__ == "__main__": - test_meta_schedule_cpu_sketch_matmul() - test_meta_schedule_cpu_sketch_matmul_relu() - test_meta_schedule_cpu_sketch_conv2d_nchw() - test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu() - test_meta_schedule_sketch_cpu_max_pool2d_nchw() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py index b691c23437db..87c81852d94e 100644 --- a/tests/python/unittest/test_meta_schedule_sketch_cuda.py +++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py @@ -82,36 +82,36 @@ def test_meta_schedule_cuda_sketch_matmul_relu(): expected = [ [ 'b0 = sch.get_block(name="C", func_name="main")', - 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "l2, l3, l4 = sch.get_loops(block=b0)", - "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", - "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", - "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", - "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", - "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", - "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", - "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", - "l31 = sch.fuse(l10, l20)", - 'sch.bind(loop=l31, thread_axis="blockIdx.x")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l11, l12, l13, l14, l15 = sch.split(loop=l3, factors=[v6, v7, v8, v9, v10])", + "v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l21, l22, l23, l24, l25 = sch.split(loop=l4, factors=[v16, v17, v18, v19, v20])", + "v26, v27, v28 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64)", + "l29, l30, l31 = sch.split(loop=l5, factors=[v26, v27, v28])", + "sch.reorder(l11, l21, l12, l22, l13, l23, l29, l30, l14, l24, l31, l15, l25)", "l32 = sch.fuse(l11, l21)", - 'sch.bind(loop=l32, thread_axis="vthread.x")', + 'sch.bind(loop=l32, thread_axis="blockIdx.x")', "l33 = sch.fuse(l12, l22)", - 'sch.bind(loop=l33, thread_axis="threadIdx.x")', - 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=1)", - "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", - "l41 = sch.fuse(l39, l40)", - "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)', - 'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=1)", - "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)", - "l51 = sch.fuse(l49, l50)", - "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)', - "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)", - 'b54 = sch.get_block(name="compute", func_name="main")', - "sch.reverse_compute_inline(block=b54)", + 'sch.bind(loop=l33, thread_axis="vthread.x")', + "l34 = sch.fuse(l13, l23)", + 'sch.bind(loop=l34, thread_axis="threadIdx.x")', + 'b35 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b35, loop=l29, preserve_unit_loops=1)", + "l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35)", + "l42 = sch.fuse(l40, l41)", + "v43, v44 = sch.sample_perfect_tile(loop=l42, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v44)', + 'b45 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b45, loop=l29, preserve_unit_loops=1)", + "l46, l47, l48, l49, l50, l51 = sch.get_loops(block=b45)", + "l52 = sch.fuse(l50, l51)", + "v53, v54 = sch.sample_perfect_tile(loop=l52, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b45, ann_key="meta_schedule.cooperative_fetch", ann_val=v54)', + "sch.reverse_compute_at(block=b2, loop=l34, preserve_unit_loops=1)", + "sch.reverse_compute_inline(block=b1)", ] ] # pylint: enable=line-too-long @@ -134,45 +134,45 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw(): # pylint: disable=line-too-long expected = [ [ - 'b0 = sch.get_block(name="compute", func_name="main")', - 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "l2, l3, l4, l5, l6, l7, l8 = sch.get_loops(block=b0)", - "v9, v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", - "l14, l15, l16, l17, l18 = sch.split(loop=l2, factors=[v9, v10, v11, v12, v13])", - "v19, v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", - "l24, l25, l26, l27, l28 = sch.split(loop=l3, factors=[v19, v20, v21, v22, v23])", - "v29, v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", - "l34, l35, l36, l37, l38 = sch.split(loop=l4, factors=[v29, v30, v31, v32, v33])", - "v39, v40, v41, v42, v43 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", - "l44, l45, l46, l47, l48 = sch.split(loop=l5, factors=[v39, v40, v41, v42, v43])", - "v49, v50, v51 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64)", - "l52, l53, l54 = sch.split(loop=l6, factors=[v49, v50, v51])", - "v55, v56, v57 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64)", - "l58, l59, l60 = sch.split(loop=l7, factors=[v55, v56, v57])", - "v61, v62, v63 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64)", - "l64, l65, l66 = sch.split(loop=l8, factors=[v61, v62, v63])", - "sch.reorder(l14, l24, l34, l44, l15, l25, l35, l45, l16, l26, l36, l46, l52, l58, l64, l53, l59, l65, l17, l27, l37, l47, l54, l60, l66, l18, l28, l38, l48)", - "l67 = sch.fuse(l14, l24, l34, l44)", - 'sch.bind(loop=l67, thread_axis="blockIdx.x")', + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")', + "l3, l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b1)", + "v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l15, l16, l17, l18, l19 = sch.split(loop=l3, factors=[v10, v11, v12, v13, v14])", + "v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l25, l26, l27, l28, l29 = sch.split(loop=l4, factors=[v20, v21, v22, v23, v24])", + "v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", + "l35, l36, l37, l38, l39 = sch.split(loop=l5, factors=[v30, v31, v32, v33, v34])", + "v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64)", + "l45, l46, l47, l48, l49 = sch.split(loop=l6, factors=[v40, v41, v42, v43, v44])", + "v50, v51, v52 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64)", + "l53, l54, l55 = sch.split(loop=l7, factors=[v50, v51, v52])", + "v56, v57, v58 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64)", + "l59, l60, l61 = sch.split(loop=l8, factors=[v56, v57, v58])", + "v62, v63, v64 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)", + "l65, l66, l67 = sch.split(loop=l9, factors=[v62, v63, v64])", + "sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l59, l65, l54, l60, l66, l18, l28, l38, l48, l55, l61, l67, l19, l29, l39, l49)", "l68 = sch.fuse(l15, l25, l35, l45)", - 'sch.bind(loop=l68, thread_axis="vthread.x")', + 'sch.bind(loop=l68, thread_axis="blockIdx.x")', "l69 = sch.fuse(l16, l26, l36, l46)", - 'sch.bind(loop=l69, thread_axis="threadIdx.x")', - 'b70 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b70, loop=l64, preserve_unit_loops=1)", - "l71, l72, l73, l74, l75, l76, l77, l78, l79, l80 = sch.get_loops(block=b70)", - "l81 = sch.fuse(l77, l78, l79, l80)", - "v82, v83 = sch.sample_perfect_tile(loop=l81, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v83)', - 'b84 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b84, loop=l64, preserve_unit_loops=1)", - "l85, l86, l87, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84)", - "l95 = sch.fuse(l91, l92, l93, l94)", - "v96, v97 = sch.sample_perfect_tile(loop=l95, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b84, ann_key="meta_schedule.cooperative_fetch", ann_val=v97)', - "sch.reverse_compute_at(block=b1, loop=l69, preserve_unit_loops=1)", - 'b98 = sch.get_block(name="pad_temp", func_name="main")', - "sch.compute_inline(block=b98)", + 'sch.bind(loop=l69, thread_axis="vthread.x")', + "l70 = sch.fuse(l17, l27, l37, l47)", + 'sch.bind(loop=l70, thread_axis="threadIdx.x")', + 'b71 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b71, loop=l65, preserve_unit_loops=1)", + "l72, l73, l74, l75, l76, l77, l78, l79, l80, l81 = sch.get_loops(block=b71)", + "l82 = sch.fuse(l78, l79, l80, l81)", + "v83, v84 = sch.sample_perfect_tile(loop=l82, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v84)', + 'b85 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b85, loop=l65, preserve_unit_loops=1)", + "l86, l87, l88, l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b85)", + "l96 = sch.fuse(l92, l93, l94, l95)", + "v97, v98 = sch.sample_perfect_tile(loop=l96, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v98)', + "sch.reverse_compute_at(block=b2, loop=l70, preserve_unit_loops=1)", + "sch.compute_inline(block=b0)", ] ] # pylint: enable=line-too-long @@ -202,53 +202,53 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disabl # pylint: disable=line-too-long expected = [ [ - 'b0 = sch.get_block(name="bias_add", func_name="main")', - 'b1 = sch.get_block(name="bn_mul", func_name="main")', - 'b2 = sch.get_block(name="bn_add", func_name="main")', + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="compute_1", func_name="main")', + "sch.compute_inline(block=b4)", + "sch.compute_inline(block=b3)", "sch.compute_inline(block=b2)", - "sch.compute_inline(block=b1)", - "sch.compute_inline(block=b0)", - 'b3 = sch.get_block(name="compute", func_name="main")', - 'b4 = sch.cache_write(block=b3, write_buffer_index=0, storage_scope="local")', - "l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b3)", - "v12, v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", - "l17, l18, l19, l20, l21 = sch.split(loop=l5, factors=[v12, v13, v14, v15, v16])", - "v22, v23, v24, v25, v26 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64)", - "l27, l28, l29, l30, l31 = sch.split(loop=l6, factors=[v22, v23, v24, v25, v26])", - "v32, v33, v34, v35, v36 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)", - "l37, l38, l39, l40, l41 = sch.split(loop=l7, factors=[v32, v33, v34, v35, v36])", - "v42, v43, v44, v45, v46 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)", - "l47, l48, l49, l50, l51 = sch.split(loop=l8, factors=[v42, v43, v44, v45, v46])", - "v52, v53, v54 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)", - "l55, l56, l57 = sch.split(loop=l9, factors=[v52, v53, v54])", - "v58, v59, v60 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64)", - "l61, l62, l63 = sch.split(loop=l10, factors=[v58, v59, v60])", - "v64, v65, v66 = sch.sample_perfect_tile(loop=l11, n=3, max_innermost_factor=64)", - "l67, l68, l69 = sch.split(loop=l11, factors=[v64, v65, v66])", - "sch.reorder(l17, l27, l37, l47, l18, l28, l38, l48, l19, l29, l39, l49, l55, l61, l67, l56, l62, l68, l20, l30, l40, l50, l57, l63, l69, l21, l31, l41, l51)", - "l70 = sch.fuse(l17, l27, l37, l47)", - 'sch.bind(loop=l70, thread_axis="blockIdx.x")', - "l71 = sch.fuse(l18, l28, l38, l48)", - 'sch.bind(loop=l71, thread_axis="vthread.x")', + 'b6 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")', + "l7, l8, l9, l10, l11, l12, l13 = sch.get_loops(block=b1)", + "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)", + "l19, l20, l21, l22, l23 = sch.split(loop=l7, factors=[v14, v15, v16, v17, v18])", + "v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)", + "l29, l30, l31, l32, l33 = sch.split(loop=l8, factors=[v24, v25, v26, v27, v28])", + "v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64)", + "l39, l40, l41, l42, l43 = sch.split(loop=l9, factors=[v34, v35, v36, v37, v38])", + "v44, v45, v46, v47, v48 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64)", + "l49, l50, l51, l52, l53 = sch.split(loop=l10, factors=[v44, v45, v46, v47, v48])", + "v54, v55, v56 = sch.sample_perfect_tile(loop=l11, n=3, max_innermost_factor=64)", + "l57, l58, l59 = sch.split(loop=l11, factors=[v54, v55, v56])", + "v60, v61, v62 = sch.sample_perfect_tile(loop=l12, n=3, max_innermost_factor=64)", + "l63, l64, l65 = sch.split(loop=l12, factors=[v60, v61, v62])", + "v66, v67, v68 = sch.sample_perfect_tile(loop=l13, n=3, max_innermost_factor=64)", + "l69, l70, l71 = sch.split(loop=l13, factors=[v66, v67, v68])", + "sch.reorder(l19, l29, l39, l49, l20, l30, l40, l50, l21, l31, l41, l51, l57, l63, l69, l58, l64, l70, l22, l32, l42, l52, l59, l65, l71, l23, l33, l43, l53)", "l72 = sch.fuse(l19, l29, l39, l49)", - 'sch.bind(loop=l72, thread_axis="threadIdx.x")', - 'b73 = sch.cache_read(block=b3, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b73, loop=l67, preserve_unit_loops=1)", - "l74, l75, l76, l77, l78, l79, l80, l81, l82, l83 = sch.get_loops(block=b73)", - "l84 = sch.fuse(l80, l81, l82, l83)", - "v85, v86 = sch.sample_perfect_tile(loop=l84, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v86)', - 'b87 = sch.cache_read(block=b3, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b87, loop=l67, preserve_unit_loops=1)", - "l88, l89, l90, l91, l92, l93, l94, l95, l96, l97 = sch.get_loops(block=b87)", - "l98 = sch.fuse(l94, l95, l96, l97)", - "v99, v100 = sch.sample_perfect_tile(loop=l98, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b87, ann_key="meta_schedule.cooperative_fetch", ann_val=v100)', - "sch.reverse_compute_at(block=b4, loop=l72, preserve_unit_loops=1)", - 'b101 = sch.get_block(name="pad_temp", func_name="main")', - 'b102 = sch.get_block(name="compute_1", func_name="main")', - "sch.reverse_compute_inline(block=b102)", - "sch.compute_inline(block=b101)", + 'sch.bind(loop=l72, thread_axis="blockIdx.x")', + "l73 = sch.fuse(l20, l30, l40, l50)", + 'sch.bind(loop=l73, thread_axis="vthread.x")', + "l74 = sch.fuse(l21, l31, l41, l51)", + 'sch.bind(loop=l74, thread_axis="threadIdx.x")', + 'b75 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b75, loop=l69, preserve_unit_loops=1)", + "l76, l77, l78, l79, l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b75)", + "l86 = sch.fuse(l82, l83, l84, l85)", + "v87, v88 = sch.sample_perfect_tile(loop=l86, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b75, ann_key="meta_schedule.cooperative_fetch", ann_val=v88)', + 'b89 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b89, loop=l69, preserve_unit_loops=1)", + "l90, l91, l92, l93, l94, l95, l96, l97, l98, l99 = sch.get_loops(block=b89)", + "l100 = sch.fuse(l96, l97, l98, l99)", + "v101, v102 = sch.sample_perfect_tile(loop=l100, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b89, ann_key="meta_schedule.cooperative_fetch", ann_val=v102)', + "sch.reverse_compute_at(block=b6, loop=l74, preserve_unit_loops=1)", + "sch.reverse_compute_inline(block=b5)", + "sch.compute_inline(block=b0)", ] ] # pylint: enable=line-too-long