From 769ec9611610d06b22739ad0168de26ab363a03b Mon Sep 17 00:00:00 2001
From: Xiyou Zhou
Date: Mon, 27 Sep 2021 10:57:44 -0700
Subject: [PATCH 1/3] Add docs.

---
 src/meta_schedule/utils.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 30294b8f91e1..132379f0f25e 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -35,6 +35,8 @@
 #include
 #include

+#include
+
 #include "../printer/text_printer.h"
 #include "../support/array.h"
 #include "../support/base64.h"

From 01e84cf0364be54e1eedec005b572dce667f59ae Mon Sep 17 00:00:00 2001
From: Xiyou Zhou
Date: Wed, 29 Sep 2021 16:23:24 -0700
Subject: [PATCH 2/3] Add TaskScheduler.

Co-authored-by: Junru Shao
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin
Co-authored-by: Siyuan Feng
---
 include/tvm/meta_schedule/runner.h | 6 +-
 include/tvm/meta_schedule/task_scheduler.h | 220 ++++++++++++++++++
 include/tvm/meta_schedule/tune_context.h | 15 ++
 python/tvm/meta_schedule/database/__init__.py | 2 +-
 .../meta_schedule/search_strategy/__init__.py | 6 +-
 .../meta_schedule/task_scheduler/__init__.py | 24 ++
 .../task_scheduler/round_robin.py | 64 +++++
 .../task_scheduler/task_scheduler.py | 122 ++++++++++
 python/tvm/meta_schedule/tune_context.py | 11 +
 python/tvm/meta_schedule/utils.py | 24 +-
 .../task_scheduler/round_robin.cc | 71 ++++++
 .../task_scheduler/task_scheduler.cc | 219 +++++++++++++++++
 src/meta_schedule/tune_context.cc | 9 +-
 src/meta_schedule/utils.h | 3 +-
 .../test_meta_schedule_task_scheduler.py | 218 +++++++++++++++++
 15 files changed, 1005 insertions(+), 9 deletions(-)
 create mode 100644 include/tvm/meta_schedule/task_scheduler.h
 create mode 100644 python/tvm/meta_schedule/task_scheduler/__init__.py
 create mode 100644 python/tvm/meta_schedule/task_scheduler/round_robin.py
 create mode 100644 python/tvm/meta_schedule/task_scheduler/task_scheduler.py
 create mode 100644 src/meta_schedule/task_scheduler/round_robin.cc
 create mode 100644 src/meta_schedule/task_scheduler/task_scheduler.cc
 create mode 100644 tests/python/unittest/test_meta_schedule_task_scheduler.py

diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h
index a45a4898d64a..c1451ae977d4 100644
--- a/include/tvm/meta_schedule/runner.h
+++ b/include/tvm/meta_schedule/runner.h
@@ -25,7 +25,7 @@
 namespace tvm {
 namespace meta_schedule {

-/*! \brief The runner's input. */
+/*! \brief The runner's input, containing the artifact path, device type and argument info. */
 class RunnerInputNode : public runtime::Object {
  public:
  /*! \brief The path to the built artifact. */
@@ -61,7 +61,7 @@ class RunnerInput : public runtime::ObjectRef {
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode);
 };

-/*! \brief The runner's output. */
+/*! \brief The runner's output, containing the measurement result or an error message if any. */
 class RunnerResultNode : public runtime::Object {
  public:
   /*! \brief The run time in seconds.*/
@@ -96,7 +96,7 @@ class RunnerResult : public runtime::ObjectRef {

 /*!
  * \brief A class to asynchronously fetch runner's output.
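 *       The future is polled via `Done()` and its final `RunnerResult` is fetched via the
 *       blocking `Result()` call (see their use in the task scheduler below).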
 * \note The API design is consistent with python's concurrent.futures.Future:
- * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
+ * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
 */
 class RunnerFutureNode : public runtime::Object {
  public:
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
new file mode 100644
index 000000000000..a2db24e31a87
--- /dev/null
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
+#define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
+
+#include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/database.h>
+#include <tvm/meta_schedule/runner.h>
+#include <tvm/meta_schedule/tune_context.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief The abstract interface of task schedulers.
+ * \note The relationship between the task scheduler and other classes is as follows:
 ┌──────────────────────────────────────────────────────────────┐
 ┌──┴───────────────────────────────────────────────────────────┐ │
┌──┴────────────────── Tune Context ───────────────────────────┐ │ │
│ ┌─────────────────────┐ │ │ │
│ │ │ Generate │ │ │
│ │ Space Generator ├──────────────┐ │ │ │
│ │ │ │ │ │ │
│ └─────────────────────┘ ▼ │ │ │
│ Design Space │ │ │
│ ┌─────────────────────┐ │ │ │ │
│ Generate │ │ Pretuning │ │ │ │
│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │
│ │ │ │ │ ├──┘
│ │ └─────────────────────┘ ├──┘
└────┼─────────────────────────────────────────────────────────┘
 │
 │
┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
│ │ ┌───────────┐ │
│ │ Send to │ │ Send to │
│ ▼ ┌─────────────►│ Builder ├──────────┐ │
│ Measure Candidate │ Builder │ │ Runner │ │
│ │ │ └───────────┘ │ │
│ │ ┌────────────┴────────┐ │ │
│ │ │ │ ┌───────────┐ │ │
│ └────►│ Task Scheduler │ │ │ │ │
│ │ │ │ Runner │◄─────────┘ │
│ └─────────────────────┘ │ │ │
│ ▲ └─────┬─────┘ │
│ │ │ │
│ └─── Runner Future ◄────┘ │
└─────────────────────────────────────────────────────────────────────┘
+*/
+class TaskSchedulerNode : public runtime::Object {
+ public:
+  /*! \brief The tasks to be tuned. */
+  Array<TuneContext> tasks;
+  /*! \brief The builder of the scheduler. */
+  Builder builder{nullptr};
+  /*! \brief The runner of the scheduler. */
+  Runner runner{nullptr};
+  /*! \brief The database of the scheduler. */
+  Database database{nullptr};
+
+  /*! \brief The default destructor. */
+  virtual ~TaskSchedulerNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("tasks", &tasks);
+    v->Visit("builder", &builder);
+    v->Visit("runner", &runner);
+    v->Visit("database", &database);
+  }
+
+  /*! \brief Auto-tuning. */
+  virtual void Tune();
+
+  /*!
+   * \brief Set specific task to be stopped.
+   * \param task_id The task id to be stopped.
+   */
+  virtual void SetTaskStopped(int task_id);
+
+  /*!
+   * \brief Check whether the task is running.
+   * \param task_id The task id to be checked.
+   * \return Whether the task is running.
+   */
+  virtual bool IsTaskRunning(int task_id);
+
+  /*!
+   * \brief Wait until the task is finished.
+   * \param task_id The task id to be joined.
+   */
+  virtual void JoinRunningTask(int task_id);
+
+  /*!
+   * \brief Fetch the next task id.
+   * \return The next task id.
+   */
+  virtual int NextTaskId() = 0;
+
+  static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
+  TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
+};
+
+/*! \brief The task scheduler with customized methods on the python-side. */
+class PyTaskSchedulerNode : public TaskSchedulerNode {
+ public:
+  /*! \brief The function type of `Tune` method. */
+  using FTune = runtime::TypedPackedFunc<void()>;
+
+  /*!
+   * \brief The function type of `SetTaskStopped` method.
+   * \param task_id The task id to be stopped.
+   */
+  using FSetTaskStopped = runtime::TypedPackedFunc<void(int)>;
+
+  /*!
+   * \brief The function type of `IsTaskRunning` method.
+   * \param task_id The task id to be checked.
+   * \return Whether the task is running.
+   */
+  using FIsTaskRunning = runtime::TypedPackedFunc<bool(int)>;
+
+  /*!
+   * \brief The function type of `JoinRunningTask` method.
+   * \param task_id The task id to be joined.
+   */
+  using FJoinRunningTask = runtime::TypedPackedFunc<void(int)>;
+
+  /*!
+   * \brief The function type of `NextTaskId` method.
+   * \return The next task id.
+   */
+  using FNextTaskId = runtime::TypedPackedFunc<int()>;
+
+  /*! \brief The packed function to the `Tune` function. */
+  FTune f_tune;
+  /*! \brief The packed function to the `SetTaskStopped` function. */
+  FSetTaskStopped f_set_task_stopped;
+  /*! \brief The packed function to the `IsTaskRunning` function. */
+  FIsTaskRunning f_is_task_running;
+  /*! \brief The packed function to the `JoinRunningTask` function. */
+  FJoinRunningTask f_join_running_task;
+  /*! \brief The packed function to the `NextTaskId` function. */
+  FNextTaskId f_next_task_id;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_tune` is not visited
+    // `f_set_task_stopped` is not visited
+    // `f_is_task_running` is not visited
+    // `f_join_running_task` is not visited
+    // `f_next_task_id` is not visited
+  }
+
+  void Tune() final {  //
+    f_tune();
+  }
+
+  void SetTaskStopped(int task_id) final {  //
+    f_set_task_stopped(task_id);
+  }
+
+  bool IsTaskRunning(int task_id) final {  //
+    return f_is_task_running(task_id);
+  }
+
+  void JoinRunningTask(int task_id) final {  //
+    f_join_running_task(task_id);
+  }
+
+  int NextTaskId() final {  //
+    return f_next_task_id();
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode);
+};
+
+/*!
+ * \brief Managed reference to TaskSchedulerNode.
+ * \sa TaskSchedulerNode
+ */
+class TaskScheduler : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Create a task scheduler that fetches tasks in a round-robin fashion.
+   * \param tasks The tasks to be tuned.
+   * \param builder The builder of the scheduler.
+   * \param runner The runner of the scheduler.
+   * \param database The database of the scheduler.
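+   * \return The task scheduler created.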
+   */
+  TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, Builder builder, Runner runner,
+                                          Database database);
+  TVM_DLL static TaskScheduler PyTaskScheduler(
+      PyTaskSchedulerNode::FTune f_tune,                          //
+      PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped,    //
+      PyTaskSchedulerNode::FIsTaskRunning f_is_task_running,      //
+      PyTaskSchedulerNode::FJoinRunningTask f_join_running_task,  //
+      PyTaskSchedulerNode::FNextTaskId f_next_task_id);
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
+};
+
+}  // namespace meta_schedule
+}  // namespace tvm
+
+#endif  // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index 87a3a491c8f3..db72328c91c3 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -36,6 +36,8 @@ class TuneContextNode : public runtime::Object {
   Optional<Target> target;
   /*! \brief The design space generator. */
   Optional<SpaceGenerator> space_generator;
+  /*! \brief The search strategy. */
+  Optional<SearchStrategy> search_strategy;
   /*! \brief The name of the tuning task. */
   Optional<String> task_name;
   /*! \brief The random state. */
@@ -43,13 +45,24 @@ class TuneContextNode : public runtime::Object {
   /*! \brief The number of threads to be used. */
   int num_threads;

+  /*! \brief Whether the tuning task has been stopped or finished. */
+  bool is_stopped;
+  /*! \brief The futures to fetch the runner results asynchronously. */
+  Optional<Array<RunnerFuture>> runner_futures;
+  /*! \brief The measure candidates. */
+  Optional<Array<MeasureCandidate>> measure_candidates;
+
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("mod", &mod);
     v->Visit("target", &target);
     v->Visit("space_generator", &space_generator);
+    v->Visit("search_strategy", &search_strategy);
     v->Visit("task_name", &task_name);
     v->Visit("rand_state", &rand_state);
     v->Visit("num_threads", &num_threads);
+    v->Visit("is_stopped", &is_stopped);
+    v->Visit("runner_futures", &runner_futures);
+    v->Visit("measure_candidates", &measure_candidates);
   }

   static constexpr const char* _type_key = "meta_schedule.TuneContext";
@@ -67,6 +80,7 @@ class TuneContext : public runtime::ObjectRef {
    * \param mod The workload to be tuned.
    * \param target The target to be tuned for.
    * \param space_generator The design space generator.
+   * \param search_strategy The search strategy.
    * \param task_name The name of the tuning task.
    * \param rand_state The random state.
    * \param num_threads The number of threads to be used.
@@ -74,6 +88,7 @@ class TuneContext : public runtime::ObjectRef {
    */
   TVM_DLL explicit TuneContext(Optional<IRModule> mod,                                    //
                                Optional<Target> target,                                   //
                                Optional<SpaceGenerator> space_generator,                  //
+                               Optional<SearchStrategy> search_strategy,                  //
                                Optional<String> task_name,                                //
                                support::LinearCongruentialEngine::TRandState rand_state,  //
                                int num_threads);
diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py
index dcd430d39407..320647b0e31b 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/database/__init__.py
@@ -18,5 +18,5 @@
 The tvm.meta_schedule.database package.
 The database that stores serialized tuning records and workloads
 """
-from .database import Database, PyDatabase, TuningRecord
+from .database import Database, PyDatabase, TuningRecord, Workload
 from .json_database import JSONDatabase
diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py
index 40f21da0b2d1..609baa267786 100644
--- a/python/tvm/meta_schedule/search_strategy/__init__.py
+++ b/python/tvm/meta_schedule/search_strategy/__init__.py
@@ -14,7 +14,11 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Search Strategy"""
+"""
+The tvm.meta_schedule.search_strategy package.
+The Meta Schedule search strategy utilizes the given design
+spaces to generate measure candidates.
+"""
 from .search_strategy import SearchStrategy, PySearchStrategy
 from .replay_trace import ReplayTrace
diff --git a/python/tvm/meta_schedule/task_scheduler/__init__.py b/python/tvm/meta_schedule/task_scheduler/__init__.py
new file mode 100644
index 000000000000..dbfe962d9966
--- /dev/null
+++ b/python/tvm/meta_schedule/task_scheduler/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+The tvm.meta_schedule.task_scheduler package.
+The Meta Schedule task scheduler manages the tuning tasks: it
+schedules measure-candidate generation and measurement, then
+saves the records to the database.
+"""
+from .task_scheduler import TaskScheduler, PyTaskScheduler
+from .round_robin import RoundRobin
diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py
new file mode 100644
index 000000000000..391011b4f53f
--- /dev/null
+++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Round Robin Task Scheduler""" + +from typing import List, TYPE_CHECKING + +from tvm._ffi import register_object + +from ..builder import Builder +from ..runner import Runner +from ..database import Database +from .task_scheduler import TaskScheduler + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.RoundRobin") +class RoundRobin(TaskScheduler): + """Round Robin Task Scheduler""" + + def __init__( + self, + tasks: List["TuneContext"], + builder: Builder, + runner: Runner, + database: Database, + ) -> None: + """Constructor. + + Parameters + ---------- + tasks : List[TuneContext] + List of tasks to schedule. + builder : Builder + The builder. + runner : Runner + The runner. + database : Database + The database. + """ + self.__init_handle_by_constructor__( + _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member + tasks, + builder, + runner, + database, + ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py new file mode 100644 index 000000000000..b8dcfd9e7a2d --- /dev/null +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Auto-tuning Task Scheduler""" +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. import _ffi_api + + +@register_object("meta_schedule.TaskScheduler") +class TaskScheduler(Object): + """The abstract task scheduler interface.""" + + def tune(self) -> None: + """Auto-tuning.""" + _ffi_api.TaskSchedulerTune(self) # pylint: disable=no-member + + def _set_task_stopped(self, task_id: int) -> None: + """Set specific task to be stopped. + + Parameters + ---------- + task_id : int + The task id to be stopped. + """ + _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member + + def _is_task_running(self, task_id: int) -> bool: + """Check whether the task is running. + + Parameters + ---------- + task_id : int + The task id to be checked. + + Returns + ------- + bool + Whether the task is running. + """ + return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member + + def _join_running_task(self, task_id: int) -> None: + """Wait until the task is finished. + + Parameters + ---------- + task_id : int + The task id to be joined. + """ + _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member + + def _next_task_id(self) -> int: + """Fetch the next task id. + + Returns + ------- + int + The next task id. 
+ """ + return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member + + +@register_object("meta_schedule.PyTaskScheduler") +class PyTaskScheduler(TaskScheduler): + """An abstract task scheduler with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_tune() -> None: + self.tune() + + def f_set_task_stopped(task_id: int) -> None: + self._set_task_stopped(task_id) + + def f_is_task_running(task_id: int) -> bool: + return self._is_task_running(task_id) + + def f_join_running_task(task_id: int) -> None: + self._join_running_task(task_id) + + def f_next_task_id() -> int: + return self._next_task_id() + + self.__init_handle_by_constructor__( + _ffi_api.TaskSchedulerPyTaskScheduler, # pylint: disable=no-member + f_tune, + f_set_task_stopped, + f_is_task_running, + f_join_running_task, + f_next_task_id, + ) + + def tune(self) -> None: + raise NotImplementedError() + + def _set_task_stopped(self, task_id: int) -> None: + _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member + + def _is_task_running(self, task_id: int) -> bool: + return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member + + def _join_running_task(self, task_id: int) -> None: + _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member + + def _next_task_id(self) -> int: + return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 9c41b4d575da..0f3cfac1a85f 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from .space_generator import SpaceGenerator + from .search_strategy import SearchStrategy @register_object("meta_schedule.TuneContext") @@ -45,6 +46,10 @@ class TuneContext(Object): The workload to be optimized. target : Optional[Target] = None The target to be optimized for. + space_generator : Optional[SpaceGenerator] = None + The design space generator. + search_strategy : Optional[SearchStrategy] = None + The search strategy. task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -63,6 +68,8 @@ class TuneContext(Object): mod: Optional[IRModule] target: Optional[Target] + space_generator: "SpaceGenerator" + search_strategy: "SearchStrategy" task_name: Optional[str] rand_state: int num_threads: int @@ -72,6 +79,7 @@ def __init__( mod: Optional[IRModule] = None, target: Optional[Target] = None, space_generator: Optional["SpaceGenerator"] = None, + search_strategy: Optional["SearchStrategy"] = None, task_name: Optional[str] = None, rand_state: int = -1, num_threads: Optional[int] = None, @@ -86,6 +94,8 @@ def __init__( The target to be optimized for. space_generator : Optional[SpaceGenerator] = None The design space generator. + search_strategy : Optional[SearchStrategy] = None + The search strategy. task_name : Optional[str] = None The name of the tuning task. 
rand_state : int = -1 @@ -102,6 +112,7 @@ def __init__( mod, target, space_generator, + search_strategy, task_name, rand_state, num_threads, diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 5f536994a9fd..bf2ef17fb308 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -21,9 +21,10 @@ from typing import Any, Callable, List, Optional, Union import psutil +import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError -from tvm.ir import Array, Map +from tvm.ir import Array, Map, IRModule from tvm.rpc import RPCSession from tvm.runtime import PackedFunc, String from tvm.tir import FloatImm, IntImm @@ -183,3 +184,24 @@ def batch_json_str2obj(json_strs: List[str]) -> List[Any]: for json_str in map(str.strip, json_strs) if json_str and (not json_str.startswith("#")) and (not json_str.startswith("//")) ] + + +def structural_hash(mod: IRModule) -> str: + """Get the structural hash of a module. + + Parameters + ---------- + mod : IRModule + The module to be hashed. + + Returns + ------- + result : str + The structural hash of the module. + """ + shash = tvm.ir.structural_hash(mod) + if shash < 0: + # Workaround because `structural_hash` returns a size_t, i.e., unsigned integer + # but ffi can't handle unsigned integers properly so it's parsed into a negative number + shash += 1 << 64 + return str(shash) diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc new file mode 100644 index 000000000000..a529f2354d87 --- /dev/null +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The round-robin style task scheduler. */ +class RoundRobinNode final : public TaskSchedulerNode { + public: + /*! \brief The current task id processed. 
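+   * \note -1 means that no task has been scheduled yet.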
+   */
+  int task_id = -1;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TaskSchedulerNode::VisitAttrs(v);
+    v->Visit("task_id", &task_id);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.RoundRobin";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RoundRobinNode, TaskSchedulerNode);
+
+ protected:
+  int NextTaskId() final {
+    int n_tasks = this->tasks.size();
+    for (int i = 0; i < n_tasks; ++i) {
+      task_id = (task_id + 1) % n_tasks;
+      TuneContext task = tasks[task_id];
+      if (!task->is_stopped) {
+        if (IsTaskRunning(task_id)) {
+          JoinRunningTask(task_id);
+        }
+        return task_id;
+      }
+    }
+    return -1;
+  }
+};
+
+TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks, Builder builder, Runner runner,
+                                        Database database) {
+  ObjectPtr<RoundRobinNode> n = make_object<RoundRobinNode>();
+  n->tasks = tasks;
+  n->builder = builder;
+  n->runner = runner;
+  n->database = database;
+  n->task_id = -1;
+  return TaskScheduler(n);
+}
+
+TVM_REGISTER_NODE_TYPE(RoundRobinNode);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin")
+    .set_body_typed(TaskScheduler::RoundRobin);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
new file mode 100644
index 000000000000..cf0af3d55fe4
--- /dev/null
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Send the measure candidates to the builder.
+ * \param builder The builder to send the candidates to.
+ * \param context The tuning context.
+ * \param candidates The measure candidates.
+ * \return An array of the builder results.
+ */
+Array<BuilderResult> SendToBuilder(const Builder& builder,  //
+                                   const TuneContext& context,
+                                   const Array<MeasureCandidate>& candidates) {
+  Target target = context->target.value();
+  Array<BuilderInput> inputs;
+  inputs.reserve(candidates.size());
+  for (const MeasureCandidate& candidate : candidates) {
+    inputs.push_back(BuilderInput(candidate->sch->mod(), target));
+  }
+  return builder->Build(inputs);
+}
+
+/*!
+ * \brief Send the built measure candidates to the runner.
+ * \param runner The runner to send the candidates to.
+ * \param context The tuning context.
+ * \param candidates The measure candidates.
+ * \param builder_results The builder results.
+ * \return An array of the runner futures.
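+ * \note Candidates that failed to build are not sent to the runner; a future that
+ *       immediately resolves to the build error is returned in their place.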
+ */
+Array<RunnerFuture> SendToRunner(const Runner& runner,  //
+                                 const TuneContext& context,
+                                 const Array<MeasureCandidate>& candidates,
+                                 const Array<BuilderResult>& builder_results) {
+  Target target = context->target.value();
+  ICHECK_EQ(candidates.size(), builder_results.size());
+  int n = candidates.size();
+  int n_build_errors = 0;
+  Array<RunnerInput> inputs;
+  inputs.reserve(n);
+  for (int i = 0; i < n; ++i) {
+    const MeasureCandidate& candidate = candidates[i];
+    const BuilderResult& builder_result = builder_results[i];
+    if (builder_result->error_msg.defined()) {
+      ++n_build_errors;
+      continue;
+    }
+    inputs.push_back(RunnerInput(/*artifact_path=*/builder_result->artifact_path.value(),
+                                 /*device_type=*/target->kind->name,
+                                 /*args_info=*/candidate->args_info));
+  }
+  Array<RunnerFuture> futures = runner->Run(inputs);
+  if (n_build_errors == 0) {
+    return futures;
+  }
+  Array<RunnerFuture> results;
+  results.reserve(n);
+  for (int i = 0, j = 0; i < n; ++i) {
+    const BuilderResult& builder_result = builder_results[i];
+    if (builder_result->error_msg.defined()) {
+      results.push_back(RunnerFuture(
+          /*f_done=*/[]() -> bool { return true; },
+          /*f_result=*/
+          [msg = builder_result->error_msg]() -> RunnerResult {
+            return RunnerResult(NullOpt, msg);
+          }));
+    } else {
+      results.push_back(futures[j++]);
+    }
+  }
+  return results;
+}
+
+void TaskSchedulerNode::Tune() {
+  for (const TuneContext& task : this->tasks) {
+    CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
+    CHECK(task->space_generator.defined())
+        << "ValueError: Require `context.space_generator`, but it is not defined";
+    CHECK(task->search_strategy.defined())
+        << "ValueError: Require `context.search_strategy`, but it is not defined";
+    IRModule mod = task->mod.value();
+    SpaceGenerator space = task->space_generator.value();
+    SearchStrategy strategy = task->search_strategy.value();
+    space->InitializeWithTuneContext(task);
+    strategy->InitializeWithTuneContext(task);
+    strategy->PreTuning(space->GenerateDesignSpace(mod));
+  }
+
+  int running_tasks = tasks.size();
+  while (running_tasks > 0) {
+    for (int task_id; (task_id = NextTaskId()) != -1;) {
+      TuneContext task = tasks[task_id];
+      ICHECK(!task->is_stopped);
+      ICHECK(!task->runner_futures.defined());
+      SearchStrategy strategy = task->search_strategy.value();
+      if (task->measure_candidates = strategy->GenerateMeasureCandidates()) {
+        Array<BuilderResult> builder_results =
+            SendToBuilder(this->builder, task, task->measure_candidates.value());
+        task->runner_futures =
+            SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results);
+      } else {
+        SetTaskStopped(task_id);
+        --running_tasks;
+      }
+    }
+    int n_tasks = this->tasks.size();
+    for (int task_id = 0; task_id < n_tasks; ++task_id)
+      if (IsTaskRunning(task_id)) {
+        TuneContext task = tasks[task_id];
+        this->JoinRunningTask(task_id);
+        task->search_strategy.value()->PostTuning();
+      }
+  }
+}
+
+void TaskSchedulerNode::SetTaskStopped(int task_id) {
+  TuneContext task = tasks[task_id];
+  ICHECK(!task->is_stopped);
+  task->is_stopped = true;
+}
+
+bool TaskSchedulerNode::IsTaskRunning(int task_id) {
+  TuneContext task = tasks[task_id];
+  if (task->is_stopped || !task->runner_futures.defined()) {
+    return false;
+  }
+  for (const RunnerFuture future : task->runner_futures.value()) {
+    if (!future->Done()) {
+      return true;
+    }
+  }
+  this->JoinRunningTask(task_id);
+  return false;
+}
+
+void TaskSchedulerNode::JoinRunningTask(int task_id) {
+  TuneContext task = tasks[task_id];
+  ICHECK(task->runner_futures.defined());
+  Array<RunnerFuture> futures = task->runner_futures.value();
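+  // Block until every future of this task resolves; build failures were already
+  // materialized as immediately-done futures in `SendToRunner`.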
+  int n = futures.size();
+  Array<RunnerResult> results;
+  results.reserve(n);
+  for (const RunnerFuture future : task->runner_futures.value()) {
+    results.push_back(future->Result());
+  }
+  task->search_strategy.value()->NotifyRunnerResults(results);
+  task->runner_futures = NullOpt;
+  // Add to database
+  ICHECK(task->measure_candidates.defined());
+  ICHECK(results.size() == task->measure_candidates.value().size());
+  int index = 0;
+  for (const RunnerResult& result : results) {
+    if (!result->error_msg.defined() && result->run_secs.defined()) {
+      Optional<tir::Trace> trace = task->measure_candidates.value()[index]->sch->trace();
+      ICHECK(trace.defined());
+      this->database->CommitTuningRecord(TuningRecord(
+          /*trace=*/trace.value(),
+          /*run_secs=*/result->run_secs.value(),
+          /*workload=*/this->database->CommitWorkload(task->mod.value()),
+          /*target=*/task->target.value(),
+          /*args_info=*/task->measure_candidates.value()[index]->args_info));
+    }
+    index++;
+  }
+}
+
+TaskScheduler TaskScheduler::PyTaskScheduler(
+    PyTaskSchedulerNode::FTune f_tune,                          //
+    PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped,    //
+    PyTaskSchedulerNode::FIsTaskRunning f_is_task_running,      //
+    PyTaskSchedulerNode::FJoinRunningTask f_join_running_task,  //
+    PyTaskSchedulerNode::FNextTaskId f_next_task_id) {
+  ObjectPtr<PyTaskSchedulerNode> n = make_object<PyTaskSchedulerNode>();
+  n->f_tune = f_tune;
+  n->f_set_task_stopped = f_set_task_stopped;
+  n->f_is_task_running = f_is_task_running;
+  n->f_join_running_task = f_join_running_task;
+  n->f_next_task_id = f_next_task_id;
+  return TaskScheduler(n);
+}
+
+TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode);
+TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler")
+    .set_body_typed(TaskScheduler::PyTaskScheduler);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerSetTaskStopped")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::SetTaskStopped);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerIsTaskRunning")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::IsTaskRunning);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::Tune);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::JoinRunningTask);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::NextTaskId);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc
index ad82b6f514a2..9fc9272e33ac 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -37,6 +37,7 @@ namespace meta_schedule {
 TuneContext::TuneContext(Optional<IRModule> mod,                                    //
                          Optional<Target> target,                                   //
                          Optional<SpaceGenerator> space_generator,                  //
+                         Optional<SearchStrategy> search_strategy,                  //
                          Optional<String> task_name,                                //
                          support::LinearCongruentialEngine::TRandState rand_state,  //
                          int num_threads) {
@@ -44,12 +45,16 @@ TuneContext::TuneContext(Optional<IRModule> mod,
   n->mod = mod;
   n->target = target;
   n->space_generator = space_generator;
+  n->search_strategy = search_strategy;
   n->task_name = task_name;
   if (rand_state == -1) {
     rand_state = std::random_device()();
   }
   support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state);
   n->num_threads = num_threads;
+  n->is_stopped = false;
+  n->runner_futures = NullOpt;
+  n->measure_candidates = NullOpt;
   data_ = std::move(n);
 }

@@ -59,10 +64,12 @@
 TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
     .set_body_typed([](Optional<IRModule> mod,                                    //
                        Optional<Target> target,                                   //
                        Optional<SpaceGenerator> space_generator,                  //
+                       Optional<SearchStrategy> search_strategy,                  //
                        Optional<String> task_name,                                //
support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { - return TuneContext(mod, target, space_generator, task_name, rand_state, num_threads); + return TuneContext(mod, target, space_generator, search_strategy, task_name, rand_state, + num_threads); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 132379f0f25e..83e65a5ced44 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -35,8 +36,6 @@ #include #include -#include - #include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py new file mode 100644 index 000000000000..bdd504cfccda --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
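+# The tests below drive the RoundRobin scheduler end to end with in-memory dummy
+# builder/runner/database implementations, so no compilation or hardware is required.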
+""" Test Meta Schedule Task Scheduler """ + +from typing import List + +import sys +import random + +import pytest + +import tvm +from tvm import tir +from tvm.script import ty +from tvm.ir import IRModule +from tvm.tir import Schedule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.search_strategy import ReplayTrace +from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult +from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult +from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload +from tvm.meta_schedule.task_scheduler import RoundRobin +from tvm.meta_schedule.utils import structural_hash + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +@tvm.script.tir +class MatmulModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (1024, 1024), "float32") + B = tir.match_buffer(b, (1024, 1024), "float32") + C = tir.match_buffer(c, (1024, 1024), "float32") + with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.tir +class MatmulReluModule: + def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (1024, 1024), "float32") + B = tir.match_buffer(b, (1024, 1024), "float32") + D = tir.match_buffer(d, (1024, 1024), "float32") + C = tir.alloc_buffer((1024, 1024), "float32") + with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + with tir.block([1024, 1024], "relu") as [vi, vj]: + D[vi, vj] = tir.max(C[vi, vj], 0.0) + + +@tvm.script.tir +class BatchMatmulModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [16, 128, 128]) + B = tir.match_buffer(b, [16, 128, 128]) + C = tir.match_buffer(c, [16, 128, 128]) + with tir.block([16, 128, 128, tir.reduce_axis(0, 128)], "matmul") as [vn, vi, vj, vk]: + with tir.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks + + +def _schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def _schedule_batch_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k, t = sch.get_loops(block=block) + # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 2, 2, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[2, 4, 64, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + t_0, t_1 = sch.split(loop=t, factors=[2, 512]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, 
j_3, t_0, t_1) + + +class DummyRunnerFuture(RunnerFuture): + def done(self) -> bool: + return True + + def result(self) -> RunnerResult: + return RunnerResult([random.uniform(5, 30) for _ in range(random.randint(1, 10))], None) + + +class DummyBuilder(PyBuilder): + def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: + return [BuilderResult("test_path", None) for _ in build_inputs] + + +class DummyRunner(PyRunner): + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + return [DummyRunnerFuture() for _ in runner_inputs] + + +class DummyDatabase(PyDatabase): + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + +def test_meta_schedule_task_scheduler_single(): + num_trials_per_iter = 3 + num_trials_total = 10 + sch_fn = ScheduleFn(sch_fn=_schedule_matmul) + replay = ReplayTrace(num_trials_per_iter, num_trials_total) + task = TuneContext( + MatmulModule(), + target=tvm.target.Target("llvm"), + space_generator=sch_fn, + search_strategy=replay, + task_name="Test", + rand_state=42, + ) + database = DummyDatabase() + round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database) + round_robin.tune() + assert len(database) == num_trials_total + + +def test_meta_schedule_task_scheduler_multiple(): + num_trials_per_iter = 6 + num_trials_total = 101 + tasks = [ + TuneContext( + MatmulModule(), + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="Matmul", + rand_state=42, + ), + TuneContext( + MatmulReluModule(), + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="MatmulRelu", + rand_state=0xDEADBEEF, + ), + TuneContext( + BatchMatmulModule(), + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="BatchMatmul", + rand_state=0x114514, + ), + ] + database = DummyDatabase() + round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database) + round_robin.tune() + assert len(database) == num_trials_total * len(tasks) + print(database.workload_reg) + for task in tasks: + assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From cd9981f0b733e178230a045620bfb5122f4b3a52 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 29 Sep 2021 18:32:19 -0700 Subject: [PATCH 3/3] Retrigger CI after hotfix.
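---
Reviewer note: a minimal end-to-end sketch of how the pieces in this series fit
together, reusing the helpers defined in the unit test above (MatmulModule,
_schedule_matmul, DummyBuilder, DummyRunner, DummyDatabase); everything else is
the API added by this series.

    import tvm
    from tvm.meta_schedule import TuneContext
    from tvm.meta_schedule.space_generator import ScheduleFn
    from tvm.meta_schedule.search_strategy import ReplayTrace
    from tvm.meta_schedule.task_scheduler import RoundRobin

    # One tuning task: a workload plus a design-space generator and a search strategy.
    task = TuneContext(
        MatmulModule(),
        target=tvm.target.Target("llvm"),
        space_generator=ScheduleFn(sch_fn=_schedule_matmul),
        search_strategy=ReplayTrace(3, 10),  # 3 trials per iteration, 10 in total
        task_name="Demo",
    )
    database = DummyDatabase()
    # The scheduler owns the tuning loop: it asks the search strategy for measure
    # candidates, sends them to the builder and runner, and commits one tuning
    # record per successful measurement into the database.
    RoundRobin([task], DummyBuilder(), DummyRunner(), database).tune()
    assert len(database) == 10  # num_trials_total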