Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule][M3a] TaskScheduler #9154

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/tvm/meta_schedule/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
namespace tvm {
namespace meta_schedule {

/*! \brief The runner's input. */
/*! \brief Runner's input containing path of artifact, type of device and argument info. */
class RunnerInputNode : public runtime::Object {
public:
/*! \brief The path to the built artifact. */
Expand Down Expand Up @@ -61,7 +61,7 @@ class RunnerInput : public runtime::ObjectRef {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode);
};

/*! \brief The runner's output. */
/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */
class RunnerResultNode : public runtime::Object {
public:
/*! \brief The run time in seconds.*/
Expand Down Expand Up @@ -96,7 +96,7 @@ class RunnerResult : public runtime::ObjectRef {
/*!
* \brief A class to asynchronously fetch runner's output.
* \note The API design is consistent with python's concurrent.futures.Future:
* https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
* https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
*/
class RunnerFutureNode : public runtime::Object {
public:
Expand Down
220 changes: 220 additions & 0 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
#define TVM_META_SCHEDULE_TASK_SCHEDULER_H_

#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/tune_context.h>

namespace tvm {
namespace meta_schedule {

/*!
* \brief The abstract interface of task schedulers.
* \note The relationship between SpaceGenerator and other classes are as follows:
┌──────────────────────────────────────────────────────────────┐
┌──┴───────────────────────────────────────────────────────────┐ │
┌──┴────────────────── Tune Context ───────────────────────────┐ │ │
│ ┌─────────────────────┐ │ │ │
│ │ │ Generate │ │ │
│ │ Space Generator ├──────────────┐ │ │ │
│ │ │ │ │ │ │
│ └─────────────────────┘ ▼ │ │ │
│ Design Space │ │ │
│ ┌─────────────────────┐ │ │ │ │
│ Generate │ │ Pretuning │ │ │ │
│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │
│ │ │ │ │ ├──┘
│ │ └─────────────────────┘ ├──┘
└────┼─────────────────────────────────────────────────────────┘
┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
│ │ ┌───────────┐ │
│ │ Send to │ │ Send to │
│ ▼ ┌─────────────►│ Builder ├──────────┐ │
│ Measure Candidate │ Builder │ │ Runner │ │
│ │ │ └───────────┘ │ │
│ │ ┌────────────┴────────┐ │ │
│ │ │ │ ┌───────────┐ │ │
│ └────►│ Task Scheduler │ │ │ │ │
│ │ │ │ Runner │◄─────────┘ │
│ └─────────────────────┘ │ │ │
│ ▲ └─────┬─────┘ │
│ │ │ │
│ └─── Runner Future ◄────┘ │
└─────────────────────────────────────────────────────────────────────┘
*/
class TaskSchedulerNode : public runtime::Object {
public:
/*! \brief The tasks to be tuned */
Array<TuneContext> tasks;
/*! \brief The builder of the scheduler. */
Builder builder{nullptr};
/*! \brief The runner of the scheduler. */
Runner runner{nullptr};
/*! \brief The database of the scheduler. */
Database database{nullptr};

/*! \brief The default desctructor. */
virtual ~TaskSchedulerNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tasks", &tasks);
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("database", &database);
}

/*! \brief Auto-tuning. */
virtual void Tune();

/*!
* \brief Set specific task to be stopped.
* \param task_id The task id to be stopped.
*/
virtual void SetTaskStopped(int task_id);

/*!
* \brief Check whether the task is running.
* \param task_id The task id to be checked.
* \return Whether the task is running.
*/
virtual bool IsTaskRunning(int task_id);

/*!
* \brief Wait until the task is finished.
* \param task_id The task id to be joined.
*/
virtual void JoinRunningTask(int task_id);

/*!
* \brief Fetch the next task id.
* \return The next task id.
*/
virtual int NextTaskId() = 0;

static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
};

/*! \brief The task scheduler with customized methods on the python-side. */
class PyTaskSchedulerNode : public TaskSchedulerNode {
public:
/*! \brief The function type of `Tune` method. */
using FTune = runtime::TypedPackedFunc<void()>;

/*!
* \brief The function type of `SetTaskStopped` method.
* \param task_id The task id to be stopped.
*/
using FSetTaskStopped = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `IsTaskRunning` method.
* \param task_id The task id to be checked.
* \return Whether the task is running.
*/
using FIsTaskRunning = runtime::TypedPackedFunc<bool(int)>;

/*!
* \brief The function type of `JoinRunningTask` method.
* \param task_id The task id to be joined.
*/
using FJoinRunningTask = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `NextTaskId` method.
* \return The next task id.
*/
using FNextTaskId = runtime::TypedPackedFunc<int()>;

/*! \brief The packed function to the `Tune` funcion. */
FTune f_tune;
/*! \brief The packed function to the `SetTaskStopped` function. */
FSetTaskStopped f_set_task_stopped;
/*! \brief The packed function to the `IsTaskRunning` function. */
FIsTaskRunning f_is_task_running;
/*! \brief The packed function to the `JoinRunningTask` function. */
FJoinRunningTask f_join_running_task;
/*! \brief The packed function to the `NextTaskId` function. */
FNextTaskId f_next_task_id;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_tune` is not visited
// `f_set_task_stopped` is not visited
// `f_is_task_running` is not visited
// `f_join_running_task` is not visited
// `f_next_task_id` is not visited
}

void Tune() final { //
f_tune();
}

void SetTaskStopped(int task_id) final { //
f_set_task_stopped(task_id);
}

bool IsTaskRunning(int task_id) final { //
return f_is_task_running(task_id);
}

void JoinRunningTask(int task_id) final { //
f_join_running_task(task_id);
}

int NextTaskId() final { //
return f_next_task_id();
}

static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode);
};

/*!
* \brief Managed reference to TaskSchedulerNode.
* \sa TaskSchedulerNode
*/
class TaskScheduler : public runtime::ObjectRef {
public:
/*!
* \brief Create a task scheduler that fetches tasks in a round-robin fashion.
* \param tasks The tasks to be tuned.
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
*/
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, Builder builder, Runner runner,
Database database);
TVM_DLL static TaskScheduler PyTaskScheduler(
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
PyTaskSchedulerNode::FNextTaskId f_next_task_id);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
15 changes: 15 additions & 0 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,33 @@ class TuneContextNode : public runtime::Object {
Optional<Target> target;
/*! \brief The design space generator. */
Optional<SpaceGenerator> space_generator;
/*! \brief The search strategy. */
Optional<SearchStrategy> search_strategy;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
int num_threads;

/*! \brief Whether the tuning task has been stopped or finished. */
bool is_stopped;
/*! \brief Packed functions to fetch the runner results asynchronously. */
Optional<Array<RunnerFuture>> runner_futures;
/*! \brief The measure candidates. */
Optional<Array<MeasureCandidate>> measure_candidates;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mod", &mod);
v->Visit("target", &target);
v->Visit("space_generator", &space_generator);
v->Visit("search_strategy", &search_strategy);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
v->Visit("is_stopped", &is_stopped);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
}

static constexpr const char* _type_key = "meta_schedule.TuneContext";
Expand All @@ -67,13 +80,15 @@ class TuneContext : public runtime::ObjectRef {
* \param mod The workload to be tuned.
* \param target The target to be tuned for.
* \param space_generator The design space generator.
* \param search_strategy The search strategy.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
*/
TVM_DLL explicit TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
The tvm.meta_schedule.database package.
The database that stores serialized tuning records and workloads
"""
from .database import Database, PyDatabase, TuningRecord
from .database import Database, PyDatabase, TuningRecord, Workload
from .json_database import JSONDatabase
6 changes: 5 additions & 1 deletion python/tvm/meta_schedule/search_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Search Strategy"""
"""
The tvm.meta_schedule.search_strategy package.
Meta Schedule search strategy utilizes the design spaces given
to generate measure candidates.
"""

from .search_strategy import SearchStrategy, PySearchStrategy
from .replay_trace import ReplayTrace
24 changes: 24 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
The tvm.meta_schedule.task_scheduler package.
Meta Schedule task scheduler that manage the task scheduling
for measure candidates generation and measurement, then save
records to the database.
"""
from .task_scheduler import TaskScheduler, PyTaskScheduler
from .round_robin import RoundRobin
Loading