Skip to content

Commit

Permalink
[Meta Schedule][M3a] TaskScheduler (apache#9154)
Browse files Browse the repository at this point in the history
* Add docs.

* Add TaskScheduler.

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>

* Retrigger CI after hotfix.

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
  • Loading branch information
7 people authored and ylc committed Jan 13, 2022
1 parent a921ab8 commit 825823c
Show file tree
Hide file tree
Showing 15 changed files with 1,005 additions and 7 deletions.
6 changes: 3 additions & 3 deletions include/tvm/meta_schedule/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
namespace tvm {
namespace meta_schedule {

/*! \brief The runner's input. */
/*! \brief Runner's input containing path of artifact, type of device and argument info. */
class RunnerInputNode : public runtime::Object {
public:
/*! \brief The path to the built artifact. */
Expand Down Expand Up @@ -61,7 +61,7 @@ class RunnerInput : public runtime::ObjectRef {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode);
};

/*! \brief The runner's output. */
/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */
class RunnerResultNode : public runtime::Object {
public:
/*! \brief The run time in seconds.*/
Expand Down Expand Up @@ -96,7 +96,7 @@ class RunnerResult : public runtime::ObjectRef {
/*!
* \brief A class to asynchronously fetch runner's output.
* \note The API design is consistent with python's concurrent.futures.Future:
* https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
* https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
*/
class RunnerFutureNode : public runtime::Object {
public:
Expand Down
220 changes: 220 additions & 0 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
#define TVM_META_SCHEDULE_TASK_SCHEDULER_H_

#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/tune_context.h>

namespace tvm {
namespace meta_schedule {

/*!
* \brief The abstract interface of task schedulers.
* \note The relationship between SpaceGenerator and other classes are as follows:
┌──────────────────────────────────────────────────────────────┐
┌──┴───────────────────────────────────────────────────────────┐ │
┌──┴────────────────── Tune Context ───────────────────────────┐ │ │
│ ┌─────────────────────┐ │ │ │
│ │ │ Generate │ │ │
│ │ Space Generator ├──────────────┐ │ │ │
│ │ │ │ │ │ │
│ └─────────────────────┘ ▼ │ │ │
│ Design Space │ │ │
│ ┌─────────────────────┐ │ │ │ │
│ Generate │ │ Pretuning │ │ │ │
│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │
│ │ │ │ │ ├──┘
│ │ └─────────────────────┘ ├──┘
└────┼─────────────────────────────────────────────────────────┘
┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
│ │ ┌───────────┐ │
│ │ Send to │ │ Send to │
│ ▼ ┌─────────────►│ Builder ├──────────┐ │
│ Measure Candidate │ Builder │ │ Runner │ │
│ │ │ └───────────┘ │ │
│ │ ┌────────────┴────────┐ │ │
│ │ │ │ ┌───────────┐ │ │
│ └────►│ Task Scheduler │ │ │ │ │
│ │ │ │ Runner │◄─────────┘ │
│ └─────────────────────┘ │ │ │
│ ▲ └─────┬─────┘ │
│ │ │ │
│ └─── Runner Future ◄────┘ │
└─────────────────────────────────────────────────────────────────────┘
*/
class TaskSchedulerNode : public runtime::Object {
public:
/*! \brief The tasks to be tuned */
Array<TuneContext> tasks;
/*! \brief The builder of the scheduler. */
Builder builder{nullptr};
/*! \brief The runner of the scheduler. */
Runner runner{nullptr};
/*! \brief The database of the scheduler. */
Database database{nullptr};

/*! \brief The default desctructor. */
virtual ~TaskSchedulerNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tasks", &tasks);
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("database", &database);
}

/*! \brief Auto-tuning. */
virtual void Tune();

/*!
* \brief Set specific task to be stopped.
* \param task_id The task id to be stopped.
*/
virtual void SetTaskStopped(int task_id);

/*!
* \brief Check whether the task is running.
* \param task_id The task id to be checked.
* \return Whether the task is running.
*/
virtual bool IsTaskRunning(int task_id);

/*!
* \brief Wait until the task is finished.
* \param task_id The task id to be joined.
*/
virtual void JoinRunningTask(int task_id);

/*!
* \brief Fetch the next task id.
* \return The next task id.
*/
virtual int NextTaskId() = 0;

static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
};

/*! \brief The task scheduler with customized methods on the python-side. */
class PyTaskSchedulerNode : public TaskSchedulerNode {
public:
/*! \brief The function type of `Tune` method. */
using FTune = runtime::TypedPackedFunc<void()>;

/*!
* \brief The function type of `SetTaskStopped` method.
* \param task_id The task id to be stopped.
*/
using FSetTaskStopped = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `IsTaskRunning` method.
* \param task_id The task id to be checked.
* \return Whether the task is running.
*/
using FIsTaskRunning = runtime::TypedPackedFunc<bool(int)>;

/*!
* \brief The function type of `JoinRunningTask` method.
* \param task_id The task id to be joined.
*/
using FJoinRunningTask = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `NextTaskId` method.
* \return The next task id.
*/
using FNextTaskId = runtime::TypedPackedFunc<int()>;

/*! \brief The packed function to the `Tune` funcion. */
FTune f_tune;
/*! \brief The packed function to the `SetTaskStopped` function. */
FSetTaskStopped f_set_task_stopped;
/*! \brief The packed function to the `IsTaskRunning` function. */
FIsTaskRunning f_is_task_running;
/*! \brief The packed function to the `JoinRunningTask` function. */
FJoinRunningTask f_join_running_task;
/*! \brief The packed function to the `NextTaskId` function. */
FNextTaskId f_next_task_id;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_tune` is not visited
// `f_set_task_stopped` is not visited
// `f_is_task_running` is not visited
// `f_join_running_task` is not visited
// `f_next_task_id` is not visited
}

void Tune() final { //
f_tune();
}

void SetTaskStopped(int task_id) final { //
f_set_task_stopped(task_id);
}

bool IsTaskRunning(int task_id) final { //
return f_is_task_running(task_id);
}

void JoinRunningTask(int task_id) final { //
f_join_running_task(task_id);
}

int NextTaskId() final { //
return f_next_task_id();
}

static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode);
};

/*!
* \brief Managed reference to TaskSchedulerNode.
* \sa TaskSchedulerNode
*/
class TaskScheduler : public runtime::ObjectRef {
public:
/*!
* \brief Create a task scheduler that fetches tasks in a round-robin fashion.
* \param tasks The tasks to be tuned.
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
*/
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, Builder builder, Runner runner,
Database database);
TVM_DLL static TaskScheduler PyTaskScheduler(
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
PyTaskSchedulerNode::FNextTaskId f_next_task_id);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
15 changes: 15 additions & 0 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,33 @@ class TuneContextNode : public runtime::Object {
Optional<Target> target;
/*! \brief The design space generator. */
Optional<SpaceGenerator> space_generator;
/*! \brief The search strategy. */
Optional<SearchStrategy> search_strategy;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
int num_threads;

/*! \brief Whether the tuning task has been stopped or finished. */
bool is_stopped;
/*! \brief Packed functions to fetch the runner results asynchronously. */
Optional<Array<RunnerFuture>> runner_futures;
/*! \brief The measure candidates. */
Optional<Array<MeasureCandidate>> measure_candidates;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mod", &mod);
v->Visit("target", &target);
v->Visit("space_generator", &space_generator);
v->Visit("search_strategy", &search_strategy);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
v->Visit("is_stopped", &is_stopped);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
}

static constexpr const char* _type_key = "meta_schedule.TuneContext";
Expand All @@ -67,13 +80,15 @@ class TuneContext : public runtime::ObjectRef {
* \param mod The workload to be tuned.
* \param target The target to be tuned for.
* \param space_generator The design space generator.
* \param search_strategy The search strategy.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
*/
TVM_DLL explicit TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
The tvm.meta_schedule.database package.
The database that stores serialized tuning records and workloads
"""
from .database import Database, PyDatabase, TuningRecord
from .database import Database, PyDatabase, TuningRecord, Workload
from .json_database import JSONDatabase
6 changes: 5 additions & 1 deletion python/tvm/meta_schedule/search_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Search Strategy"""
"""
The tvm.meta_schedule.search_strategy package.
Meta Schedule search strategy utilizes the design spaces given
to generate measure candidates.
"""

from .search_strategy import SearchStrategy, PySearchStrategy
from .replay_trace import ReplayTrace
24 changes: 24 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
The tvm.meta_schedule.task_scheduler package.
Meta Schedule task scheduler that manage the task scheduling
for measure candidates generation and measurement, then save
records to the database.
"""
from .task_scheduler import TaskScheduler, PyTaskScheduler
from .round_robin import RoundRobin
Loading

0 comments on commit 825823c

Please sign in to comment.