Skip to content

Commit

Permalink
[Meta Schedule] Add Helper Function & Minor Modification (apache#512)
Browse files Browse the repository at this point in the history
* Add TuneContext helper func for init.

* Fix tests.

* Rebase.
  • Loading branch information
zxybazh committed Nov 11, 2021
1 parent 996cf62 commit 4e1c53a
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 169 deletions.
3 changes: 3 additions & 0 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class TuneContextNode : public runtime::Object {
v->Visit("measure_candidates", &measure_candidates);
}

/*! \brief Initialize members that needs initialization with tune context. */
void Initialize();

static constexpr const char* _type_key = "meta_schedule.TuneContext";
TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
};
Expand Down
5 changes: 4 additions & 1 deletion src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
Array<tir::Schedule> result{sch};
// Enumerate the schedule rules first because you can
// always concat multiple schedule rules as one
Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch);
for (ScheduleRule sch_rule : sch_rules_) {
for (const tir::Schedule& sch : result) {
stack.emplace_back(sch, BlockCollector::Collect(sch));
stack.emplace_back(sch, all_blocks);
}
result.clear();

Expand All @@ -137,6 +138,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
for (const tir::Schedule& sch : applied) {
stack.emplace_back(sch, blocks);
}
} else {
stack.emplace_back(sch, blocks);
}
}
}
Expand Down
24 changes: 2 additions & 22 deletions src/meta_schedule/task_scheduler/task_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,8 @@ Array<RunnerFuture> SendToRunner(const Runner& runner, //
return results;
}

void TaskSchedulerNode::InitializeTask(int task_id) {
TuneContext task = this->tasks[task_id];
// Derive the values.
IRModule mod = task->mod.value();
SpaceGenerator space = task->space_generator.value();
SearchStrategy strategy = task->search_strategy.value();
// Initialize Modules.
space->InitializeWithTuneContext(task);
strategy->InitializeWithTuneContext(task);
// Initialize the rules.
if (task->sch_rules.defined())
for (const ScheduleRule& sch_rule : task->sch_rules.value()) {
sch_rule->InitializeWithTuneContext(task);
}
if (task->mutators.defined())
for (const Mutator& mutator : task->mutators.value()) {
mutator->InitializeWithTuneContext(task);
}
if (task->postprocs.defined())
for (const Postproc& postproc : task->postprocs.value()) {
postproc->InitializeWithTuneContext(task);
}
void TaskSchedulerNode::InitializeTask(int task_id) { //
this->tasks[task_id]->Initialize();
}

void TaskSchedulerNode::Tune() {
Expand Down
24 changes: 24 additions & 0 deletions src/meta_schedule/tune_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ TuneContext::TuneContext(Optional<IRModule> mod,
data_ = std::move(n);
}

void TuneContextNode::Initialize() {
if (this->space_generator.defined()) {
this->space_generator.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));
}
if (this->search_strategy.defined()) {
this->search_strategy.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));
}
if (this->sch_rules.defined()) {
for (const ScheduleRule& sch_rule : sch_rules.value()) {
sch_rule->InitializeWithTuneContext(GetRef<TuneContext>(this));
}
}
if (this->postprocs.defined()) {
for (const Postproc& postproc : postprocs.value()) {
postproc->InitializeWithTuneContext(GetRef<TuneContext>(this));
}
}
if (this->mutators.defined()) {
for (const Mutator& mutator : mutators.value()) {
mutator->InitializeWithTuneContext(GetRef<TuneContext>(this));
}
}
}

TVM_REGISTER_NODE_TYPE(TuneContextNode);

TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
Expand Down
18 changes: 8 additions & 10 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,14 @@ def correct_trace(a, b, c, d):
'b1 = sch.get_block(name="B", func_name="main")',
'b2 = sch.get_block(name="C", func_name="main")',
"sch.compute_inline(block=b1)",
'b3 = sch.get_block(name="A", func_name="main")',
'b4 = sch.get_block(name="C", func_name="main")',
"l5, l6 = sch.get_loops(block=b4)",
"l7, l8 = sch.split(loop=l5, factors=" + str(a) + ")",
"l9, l10 = sch.split(loop=l6, factors=" + str(b) + ")",
"sch.reorder(l7, l9, l8, l10)",
"l11, l12 = sch.get_loops(block=b3)",
"l13, l14 = sch.split(loop=l11, factors=" + str(c) + ")",
"l15, l16 = sch.split(loop=l12, factors=" + str(d) + ")",
"sch.reorder(l13, l15, l14, l16)",
"l3, l4 = sch.get_loops(block=b2)",
"l5, l6 = sch.split(loop=l3, factors=" + str(a) + ")",
"l7, l8 = sch.split(loop=l4, factors=" + str(b) + ")",
"sch.reorder(l5, l7, l6, l8)",
"l9, l10 = sch.get_loops(block=b0)",
"l11, l12 = sch.split(loop=l9, factors=" + str(c) + ")",
"l13, l14 = sch.split(loop=l10, factors=" + str(d) + ")",
"sch.reorder(l11, l13, l12, l14)",
]
)

Expand Down
50 changes: 24 additions & 26 deletions tests/python/unittest/test_meta_schedule_sketch_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring

from typing import List
import sys

import pytest

from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.testing.space_generation import check_trace, create_context
Expand Down Expand Up @@ -231,15 +233,15 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable
# pylint: disable=line-too-long
expected = [
[
'b0 = sch.get_block(name="bias_add", func_name="main")',
'b1 = sch.get_block(name="bn_mul", func_name="main")',
'b2 = sch.get_block(name="bn_add", func_name="main")',
'b0 = sch.get_block(name="compute", func_name="main")',
'b1 = sch.get_block(name="bias_add", func_name="main")',
'b2 = sch.get_block(name="bn_mul", func_name="main")',
'b3 = sch.get_block(name="bn_add", func_name="main")',
"sch.compute_inline(block=b3)",
"sch.compute_inline(block=b2)",
"sch.compute_inline(block=b1)",
"sch.compute_inline(block=b0)",
'b3 = sch.get_block(name="compute", func_name="main")',
"b4, = sch.get_consumers(block=b3)",
"l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b3)",
"b4, = sch.get_consumers(block=b0)",
"l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b0)",
"v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)",
"l16, l17, l18, l19 = sch.split(loop=l5, factors=[v12, v13, v14, v15])",
"v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)",
Expand All @@ -258,15 +260,15 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable
"sch.reverse_compute_at(block=b4, loop=l41, preserve_unit_loops=1)",
],
[
'b0 = sch.get_block(name="bias_add", func_name="main")',
'b1 = sch.get_block(name="bn_mul", func_name="main")',
'b2 = sch.get_block(name="bn_add", func_name="main")',
'b0 = sch.get_block(name="compute", func_name="main")',
'b1 = sch.get_block(name="bias_add", func_name="main")',
'b2 = sch.get_block(name="bn_mul", func_name="main")',
'b3 = sch.get_block(name="bn_add", func_name="main")',
"sch.compute_inline(block=b3)",
"sch.compute_inline(block=b2)",
"sch.compute_inline(block=b1)",
"sch.compute_inline(block=b0)",
'b3 = sch.get_block(name="compute", func_name="main")',
"b4, = sch.get_consumers(block=b3)",
"l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b3)",
"b4, = sch.get_consumers(block=b0)",
"l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b0)",
"v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)",
"l16, l17, l18, l19 = sch.split(loop=l5, factors=[v12, v13, v14, v15])",
"v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)",
Expand All @@ -285,14 +287,14 @@ def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable
"sch.reverse_compute_at(block=b4, loop=l40, preserve_unit_loops=1)",
],
[
'b0 = sch.get_block(name="bias_add", func_name="main")',
'b1 = sch.get_block(name="bn_mul", func_name="main")',
'b2 = sch.get_block(name="bn_add", func_name="main")',
'b0 = sch.get_block(name="compute", func_name="main")',
'b1 = sch.get_block(name="bias_add", func_name="main")',
'b2 = sch.get_block(name="bn_mul", func_name="main")',
'b3 = sch.get_block(name="bn_add", func_name="main")',
"sch.compute_inline(block=b3)",
"sch.compute_inline(block=b2)",
"sch.compute_inline(block=b1)",
"sch.compute_inline(block=b0)",
'b3 = sch.get_block(name="compute", func_name="main")',
"l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b3)",
"l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0)",
"v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)",
"l15, l16, l17, l18 = sch.split(loop=l4, factors=[v11, v12, v13, v14])",
"v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)",
Expand Down Expand Up @@ -352,8 +354,4 @@ def test_meta_schedule_sketch_cpu_max_pool2d_nchw():


if __name__ == "__main__":
test_meta_schedule_cpu_sketch_matmul()
test_meta_schedule_cpu_sketch_matmul_relu()
test_meta_schedule_cpu_sketch_conv2d_nchw()
test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu()
test_meta_schedule_sketch_cpu_max_pool2d_nchw()
sys.exit(pytest.main([__file__] + sys.argv[1:]))
Loading

0 comments on commit 4e1c53a

Please sign in to comment.