diff --git a/CMakeLists.txt b/CMakeLists.txt index d7faa8a4b666..9f5c5084d6c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,6 +185,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) # Source file lists file(GLOB_RECURSE COMPILER_SRCS + src/auto_schedule/*.cc src/node/*.cc src/ir/*.cc src/arith/*.cc diff --git a/python/tvm/auto_schedule/__init__.py b/python/tvm/auto_schedule/__init__.py new file mode 100644 index 000000000000..90bec8665cef --- /dev/null +++ b/python/tvm/auto_schedule/__init__.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +""" Namespace for TVM Auto-scheduler. """ + +from . import compute_dag +from . import measure +from . import measure_record +from . import loop_state +from . import utils +from . import workload_registry + +# Shortcut +from .compute_dag import ComputeDAG +from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \ + auto_schedule, EmptyPolicy +from .measure import MeasureInput, LocalBuilder, LocalRunner +from .measure_record import RecordToFile, RecordReader, load_best, \ + load_records, save_records +from .workload_registry import register_workload, make_workload_key diff --git a/python/tvm/auto_schedule/_ffi_api.py b/python/tvm/auto_schedule/_ffi_api.py new file mode 100644 index 000000000000..9d2b9865ae95 --- /dev/null +++ b/python/tvm/auto_schedule/_ffi_api.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Register FFI APIs from C++ for the namespace tvm.auto_schedule. """ +import tvm._ffi + + +tvm._ffi._init_api("auto_schedule", __name__) diff --git a/python/tvm/auto_schedule/auto_schedule.py b/python/tvm/auto_schedule/auto_schedule.py new file mode 100644 index 000000000000..ffbfc3c914ff --- /dev/null +++ b/python/tvm/auto_schedule/auto_schedule.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +User interface for TVM Auto-scheduler. + +The basic schedule search process for TVM Auto-scheduler is designed to be: +`Program sampling` -> `Performance Tuning`. + +In `Program sampling`, we use some predefined precise or heuristic rules to generate several +initial schedules. Based on these initial starting points, we perform `Performance Tuning` which +uses cost model based evolutionary search to select schedules with the best performance. + +Candidate schedules are measured against the specific hardware target. +""" + +import tvm._ffi +from tvm.runtime import Object +from .measure import LocalBuilder, LocalRunner +from . import _ffi_api + + +@tvm._ffi.register_object("auto_schedule.HardwareParams") +class HardwareParams(Object): + """ The parameters of target hardware used to guide the search policy + + TODO(jcf94): This is considered to be merged with the new Target specification: + https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844 + + Parameters + ---------- + num_cores : int + The number of device cores. + vector_unit_bytes : int + The width of vector units in bytes. + cache_line_bytes : int + The size of cache line in bytes. + """ + def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes): + self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores, + vector_unit_bytes, cache_line_bytes) + + +@tvm._ffi.register_object("auto_schedule.SearchTask") +class SearchTask(Object): + """ The computation information and hardware parameters for a specific schedule search task. + + Parameters + ---------- + dag : ComputeDAG + The ComputeDAG for the corresponding compute declaration. + workload_key : str + The workload key for the corresponding compute declaration. + target : tvm.target.Target + The target device of this search task. + target_host : Optional[tvm.target.Target] + The target host device of this search task. + hardware_params : Optional[HardwareParams] + Hardware parameters used in this search task. + """ + def __init__(self, dag, workload_key, target, target_host=None, + hardware_params=None): + self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag, + workload_key, target, target_host, + hardware_params) + + +@tvm._ffi.register_object("auto_schedule.SearchPolicy") +class SearchPolicy(Object): + """ The base class of search policies. """ + + +@tvm._ffi.register_object("auto_schedule.EmptyPolicy") +class EmptyPolicy(SearchPolicy): + """ This is an example empty search policy which will always generate + the init state of ComputeDAG. + """ + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) + + +@tvm._ffi.register_object("auto_schedule.TuningOptions") +class TuningOptions(Object): + """ This controls the options of performance tuning. + + Parameters + ---------- + num_measure_trials: int = 0 + The number of measurement trials. + The search policy measures `num_measure_trials` schedules in total and returns the best one + among them. + With `num_measure_trials` == 0, the policy will do the schedule search but won't involve + measurement. This can be used to get a runnable schedule quickly without auto-tuning. + early_stopping: Optional[int] + Stop the tuning early if getting no improvement after n measurements. + num_measures_per_round: int = 64 + The number of schedules to be measured at each search round. + The whole schedule search process will try a total number of `num_measure_trials` in several + rounds. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during schedule search. + builder: Union[ProgramBuilder, str] = 'local' + ProgramBuilder which builds the program. + runner: Union[ProgramRunner, str] = 'local' + ProgramRunner which runs the program and measures time costs. + measure_callbacks: Optional[List[MeasureCallback]] + Callback functions called after each measurement. + Candidates: + - auto_schedule.RecordToFile + pre_search_callbacks: Optional[List[SearchCallback]] + Callback functions called before the search process. + Candidates: + - auto_schedule.PreloadMeasuredStates + - auto_schedule.PreloadCustomSketchRule + TODO(jcf94): Add these implementation in later PRs. + """ + def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_round=64, + verbose=1, builder='local', runner='local', measure_callbacks=None, + pre_search_callbacks=None): + if isinstance(builder, str): + if builder == 'local': + builder = LocalBuilder() + else: + raise ValueError("Invalid builder: " + builder) + elif not isinstance(builder, tvm.auto_schedule.measure.ProgramBuilder): + raise ValueError("Invalid builder: " + builder + + " . TuningOptions expects a ProgramBuilder or string.") + + if isinstance(runner, str): + if runner == 'local': + runner = LocalRunner() + else: + raise ValueError("Invalid runner: " + runner) + elif not isinstance(runner, tvm.auto_schedule.measure.ProgramRunner): + raise ValueError("Invalid runner: " + runner + + " . TuningOptions expects a ProgramRunner or string.") + + self.__init_handle_by_constructor__( + _ffi_api.TuningOptions, num_measure_trials, early_stopping if early_stopping else -1, + num_measures_per_round, verbose, builder, runner, measure_callbacks, + pre_search_callbacks) + + +def auto_schedule(task, search_policy='default', tuning_options=None): + """ Do auto scheduling for a computation declaration. + + The task parameter can be a `string` as workload_key, or directly + passing a `SearchTask` as input. + + Parameters + ---------- + task : SearchTask + The SearchTask for the computation declaration. + search_policy : Union[SearchPolicy, str] = 'default' + The search policy to be used for schedule search. + tuning_options : Optional[TuningOptions] + Tuning and measurement options. + + Returns + ------- + A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`. + """ + if not isinstance(task, SearchTask): + raise ValueError("Invalid task: " + task + + " . `auto_schedule.auto_schedule` expects a SearchTask.") + + if isinstance(search_policy, str): + if search_policy == 'default': + # TODO(jcf94): This is an example policy for minimum system, will be upgrated to + # formal search policy later. + search_policy = EmptyPolicy() + else: + raise ValueError("Invalid search policy: " + search_policy) + elif not isinstance(search_policy, SearchPolicy): + raise ValueError("Invalid search policy: " + search_policy + + " . `auto_schedule.auto_schedule` expects a SearchPolicy or a string.") + + sch, tensors = _ffi_api.AutoSchedule(task, search_policy, + tuning_options if tuning_options else TuningOptions()) + return sch, tensors diff --git a/python/tvm/auto_schedule/compute_dag.py b/python/tvm/auto_schedule/compute_dag.py new file mode 100644 index 000000000000..a4738a933b3e --- /dev/null +++ b/python/tvm/auto_schedule/compute_dag.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" The TVM Auto-scheduler computational graph and related program analyses. """ + +import hashlib + +import tvm._ffi +from tvm.runtime import Object +from tvm.te import PlaceholderOp, ComputeOp + +from .loop_state import State, StateObject +from .utils import get_const_tuple +from .workload_registry import workload_key_to_tensors + +from . import _ffi_api + + +@tvm._ffi.register_object("auto_schedule.ComputeDAG") +class ComputeDAG(Object): + """ + The TVM Auto-scheduler computational graph and related program analyses. + + We convert a compute declaration described by `tvm.compute` (could be a single operator or a + subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, + a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the + total float operation count, consumer/producer relations of each operation stage, whether an + operation stage should be tiled/compute inlined ...). These analyses can help the search policy + to make decisions during search process. + ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and + TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing + `LoopState` with extra information got from TVM schedule ...). + + Parameters + ---------- + compute : Union[List[Tensor], str] + `Tensor`s or workload key for a compute declaration. + """ + def __init__(self, compute): + if isinstance(compute, str): + compute = workload_key_to_tensors(compute) + elif isinstance(compute, list): + for item in compute: + if not isinstance(item, tvm.te.Tensor): + raise ValueError("The input of ComputeDAG should be a list of Tensor") + else: + raise ValueError("Invalid compute: " + compute + + " . ComputeDAG expects a string or list of Tensor") + self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute) + + def get_init_state(self): + """ Get the init state of this ComputeDAG. + + Returns + ------- + state : State + The initial State without any transform steps. + """ + return State(self.init_state, self) + + def apply_steps_from_state(self, state): + """ + Apply the history transform steps from a State to get a TVM schedule. + + Parameters + ---------- + state : Union[State, StateObject] + The state from which we get transform steps. + + Returns + ------- + A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`. + """ + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj) + + def print_python_code_from_state(self, state): + """ + Print transform steps in the history of a State as TVM's python schedule primitive. + + This is used to print transformation steps for debugging. + Use `apply_steps_from_state` if you want to get a schedule for code generation. + + Parameters + ---------- + state : Union[State, StateObject] + The state from which we get transform steps. + + Returns + ------- + str : Str + The Python schedule code. + """ + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj) + + def infer_bound_from_state(self, state): + """ + Infer and fill the bound of all iterators of a state. + + The states may lose complete bound information after some transform steps + (e.g., compute_at). + We can call this function to infer and fill all the bound information. + This function calls TVM InferBound pass internally to get the bound. + The returned state of this function is guaranteed to have complete iterator extent + information. + + Parameters + ---------- + state : Union[State, StateObject] + The state from which we get transform steps. + + Returns + ------- + state : State + The State with complete bound information. + """ + state_obj = state if isinstance(state, StateObject) else state.state_object + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) + + def __hash__(self): + # TODO(merrymercy): Implement this more carefully and move this to c++ as a member function + # of ComputeDAG + str_key = '' + for op in self.ops: + t = op.output(0) + if isinstance(op, PlaceholderOp): + str_key += 'placeholder,' + str_key += str(get_const_tuple(t.shape)) + ',' + str_key += t.dtype + ';' + elif isinstance(op, ComputeOp): + str_key += str(t.op.body) + ',' + str_key += str(get_const_tuple(t.shape)) + ',' + str_key += t.dtype + ';' + else: + raise ValueError("Invalid op: " + op) + + str_key = str_key.encode(encoding='utf-8') + return hashlib.md5(str_key).hexdigest() diff --git a/python/tvm/auto_schedule/loop_state.py b/python/tvm/auto_schedule/loop_state.py new file mode 100644 index 000000000000..7b8804c8be60 --- /dev/null +++ b/python/tvm/auto_schedule/loop_state.py @@ -0,0 +1,208 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import + +""" +The definition of the "state" in search. + +Each LoopState corresponds to a schedule for its ComputeDAG. +A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to +construct the loop structure. +The loop structure keeps a preview of how the schedule will finally look like after lowering the +current state (e.g. number of iterators, the extent of each iterator, the compute_at locations ...). +During the schedule search process, the loop structure can provide search policy with necessary +information on how to manipulate the current state. +The transform history is a sequence of `TransformStep` which will finally be mapped to TVM schedule +primitives. The steps can also be used for the serialization of a state. + +The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. +We don't use the existing TVM IR but to extend a new structure on it is because: +1. We want fast incremental change to the loop structures. The search policy needs to get the +immediate loop structures update rather than after TVM lowering; +2. We want serializable transform history for replay, backtracking, and mutation; +3. We may create some macro schedule primitives that represent the combination of several +TVM schedule primitives. + +When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives. +Since we share a lot of common objects during search, the transformation is implemented in +copy on write style. All objects are immutable, which is similar to TVM IR. +""" + +import tvm._ffi +from tvm.te.tensor import Operation, Tensor +from tvm.runtime import Object +from . import _ffi_api + + +@tvm._ffi.register_object("auto_schedule.Iterator") +class Iterator(Object): + """ A loop iterator structure. """ + + +@tvm._ffi.register_object("auto_schedule.Stage") +class Stage(Object): + """ A stage in the compute declaration. Similar to tvm.te.schedule.Stage. """ + + +@tvm._ffi.register_object("auto_schedule.State") +class StateObject(Object): + """ The internal State object """ + def __eq__(self, other): + return _ffi_api.StateEqual(self, other) + + +class State: + """ + A state in the search process. It consists of the current loop structure + and a list of transformation steps used to construct it. + + Each State corresponds to a specific schedule for its ComputeDAG. + + Parameters + ---------- + state_object : StateObject + The StateObject corresponding to C++ internal State object. + dag : ComputeDAG + The original ComputeDAG of this State. + + Notes + ----- + This is a wrapper class of StateObject to deal with copy-on-write property + """ + def __init__(self, state_object, dag): + self.state_object = state_object + self.compute_dag = dag + + self.stage_id_map = {} # A dict maps operation to stage id + self._update_stage_id_map() + + @property + def stages(self): + """ + Returns + ------- + stages : List[Stage] + """ + return self.state_object.stages + + @property + def stage_ops(self): + """ + Returns + ------- + ops: List[Operation] + """ + return [stage.op for stage in self.stages] + + def reorder(self, stage, order): + """ Schedule primitive corresponds to te.reorder. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be reordered, can be a Stage order index, Stage operation or stage + output tensor. + order : List[Iterator] + Iterators in the expected order. + """ + stage_id = self._resolve_stage_id(stage) + + self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) + + def split(self, stage, iterator, lengths, inner_to_outer=True): + """ Schedule primitive corresponds to te.split. + + This API supports multiple split factors. (e.g. with 2 split factors, the original iterator + will be split to 3 parts, use `inner_to_outer` to control the split order) + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be split, can be a Stage order index, Stage operation or stage + output tensor. + iterator : Iterator + The iterator to be split. + lengths: List[int] + The multiple split factors. Can be None to be filled by search policy. + inner_to_outer: boolean = True + Whether the factor go from inner to outer, or from outer to inner. + + Returns + ------- + res_its : List[Iterator] + The splitted new Iterators + """ + stage_id = self._resolve_stage_id(stage) + + self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths, + inner_to_outer) + return res + + def fuse(self, stage, iters): + """ Schedule primitive corresponds to te.fuse. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be fused, can be a Stage order index, Stage operation or stage + output tensor. + iters : List[Iterator] + The iterators to be fused + + Returns + ------- + res_it : Iterator + The fused Iterator + """ + stage_id = self._resolve_stage_id(stage) + + self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) + return res + + def copy(self): + """ Do deep copy of this State. """ + state = State(self.state_object, self.compute_dag) + state.stage_id_map = self.stage_id_map.copy() + return state + + def _resolve_stage_id(self, stage_id): + if isinstance(stage_id, Operation): + return self.stage_id_map[stage_id] + if isinstance(stage_id, Tensor): + return self.stage_id_map[stage_id.op] + if isinstance(stage_id, int): + return stage_id + raise ValueError("Invalid stage: " + stage_id + + " . Expect to be a int, Operation or Tensor") + + def _update_stage_id_map(self): + for index, stage in enumerate(self.stages): + self.stage_id_map[stage.op] = index + + def __getitem__(self, key): + if isinstance(key, Tensor): + key = key.op + if isinstance(key, Operation): + return self.stages[self.stage_id_map[key]] + raise ValueError("Invalid item: " + key + + " . Expect to be a Operation or Tensor") + + def __str__(self): + return str(self.state_object) + + def __eq__(self, other): + return _ffi_api.StateEqual(self.state_object, other.state_object) diff --git a/python/tvm/auto_schedule/measure.py b/python/tvm/auto_schedule/measure.py new file mode 100644 index 000000000000..24e2af1d8f49 --- /dev/null +++ b/python/tvm/auto_schedule/measure.py @@ -0,0 +1,480 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Distributed measurement infrastructure to measure the runtime costs of tensor programs. + +These functions are responsible for building the tvm module, uploading it to +remote devices, recording the running time costs, and checking the correctness of the output. + +We separate the measurement into two steps: build and run. +A builder builds the executable binary files and a runner runs the binary files to +get the measurement results. The flow of data structures is + + `ProgramBuilder` `ProgramRunner` +`MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult` + +We implement these in python to utilize python's multiprocessing and error handling. +""" + +import os +import time +import shutil +import traceback +import tempfile +import multiprocessing + +import tvm._ffi +from tvm.runtime import Object, module, ndarray +from tvm.driver import build_module +from tvm.ir import transform +from tvm.contrib import tar, ndk + +from . import _ffi_api +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout + +# The maximum length of error message +MAX_ERROR_MSG_LEN = 512 + +# We use fork and a global variable to copy arguments between processings. +# This can avoid expensive serialization of TVM IR when using multiprocessing.Pool +GLOBAL_BUILD_ARGUMENTS = None + +@tvm._ffi.register_object("auto_schedule.MeasureCallback") +class MeasureCallback(Object): + """ The base class of measurement callback functions. """ + + +@tvm._ffi.register_object("auto_schedule.MeasureInput") +class MeasureInput(Object): + """ Store the input of a measurement. + + Parameters + ---------- + task : SearchTask + The SearchTask of this measure. + state : State + The State to be measured. + """ + def __init__(self, task, state): + self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object) + + +@tvm._ffi.register_object("auto_schedule.BuildResult") +class BuildResult(Object): + """ Store the result of a build. + + Parameters + ---------- + filename : Optional[str] + The filename of built binary file. + args : List[Tensor] + The arguments. + error_no : int + The error code. + error_msg : Optional[str] + The error message if there is any error. + time_cost : float + The time cost of build. + """ + def __init__(self, filename, args, error_no, error_msg, time_cost): + filename = filename if filename else "" + error_msg = error_msg if error_msg else "" + + self.__init_handle_by_constructor__( + _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost) + + +@tvm._ffi.register_object("auto_schedule.MeasureResult") +class MeasureResult(Object): + """ Store the results of a measurement. + + Parameters + ---------- + costs : List[float] + The time costs of execution. + error_no : int + The error code. + error_msg : Optional[str] + The error message if there is any error. + all_cost : float + The time cost of build and run. + timestamp : float + The time stamps of this measurement. + """ + def __init__(self, costs, error_no, error_msg, all_cost, timestamp): + error_msg = error_msg if error_msg else "" + + self.__init_handle_by_constructor__( + _ffi_api.MeasureResult, costs, error_no, + error_msg, all_cost, timestamp) + + +@tvm._ffi.register_object("auto_schedule.ProgramBuilder") +class ProgramBuilder(Object): + """ The base class of ProgramBuilders. """ + + def build(self, measure_inputs, verbose=1): + """ Build programs and return results. + + Parameters + ---------- + measure_inputs : List[MeasureInput] + A List of MeasureInput. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during program building. + + Returns + ------- + res : List[BuildResult] + """ + return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose) + + +@tvm._ffi.register_object("auto_schedule.ProgramRunner") +class ProgramRunner(Object): + """ The base class of ProgramRunners. """ + + def run(self, measure_inputs, build_results, verbose=1): + """ Run measurement and return results. + + Parameters + ---------- + measure_inputs : List[MeasureInput] + A List of MeasureInput. + build_results : List[BuildResult] + A List of BuildResult to be ran. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during program running. + + Returns + ------- + res : List[MeasureResult] + """ + return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose) + + +@tvm._ffi.register_object("auto_schedule.LocalBuilder") +class LocalBuilder(ProgramBuilder): + """ LocalBuilder use local CPU cores to build programs in parallel. + + Parameters + ---------- + timeout : int = 15 + The timeout limit (in second) for each build thread. + This is used in a wrapper of the multiprocessing.Process.join(). + n_parallel : int = multiprocessing.cpu_count() + Number of threads used to build in parallel. + build_func : str = 'default' + The name of registered build function. + """ + + def __init__(self, + timeout=15, + n_parallel=multiprocessing.cpu_count(), + build_func='default'): + self.__init_handle_by_constructor__( + _ffi_api.LocalBuilder, timeout, n_parallel, build_func) + + +@tvm._ffi.register_object("auto_schedule.LocalRunner") +class LocalRunner(ProgramRunner): + """ LocalRunner that uses local CPU/GPU to measures the time cost of programs. + + Parameters + ---------- + timeout : int = 10 + The timeout limit (in second) for each run. + This is used in a wrapper of the multiprocessing.Process.join(). + number : int = 3 + The number of times to run the generated code for taking average. + We call these runs as one `repeat` of measurement. + repeat : int = 1 + The number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first "1" is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + min_repeat_ms : int = 0 + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + cooldown_interval : float = 0.0 + The cool down interval between two measurements. + """ + + def __init__(self, + timeout=10, + number=3, + repeat=1, + min_repeat_ms=0, + cooldown_interval=0.0): + self.__init_handle_by_constructor__( + _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) + + +class MeasureErrorNo(object): + """ Error type for MeasureResult. """ + NO_ERROR = 0 # No error + INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state + # Errors happen when compiling code on host (e.g. tvm.build) + COMPILE_HOST = 2 + COMPILE_DEVICE = 3 # Errors happen when compiling code on device + # (e.g. OpenCL JIT on the device) + RUNTIME_DEVICE = 4 # Errors happen when run program on device + WRONG_ANSWER = 5 # Answer is wrong when compared to a reference output + BUILD_TIMEOUT = 6 # Timeout during compilation + RUN_TIMEOUT = 7 # Timeout during run + UNKNOWN_ERROR = 8 # Unknown error + + +def make_error_msg(): + """ Get the error message from traceback. """ + error_msg = str(traceback.format_exc()) + if len(error_msg) > MAX_ERROR_MSG_LEN: + error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \ + "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:] + return error_msg + + +def local_build_worker(index): + """ + Build function of LocalBuilder to be ran in the Builder thread pool. + + Parameters + ---------- + index : int + The MeasureInput index to be processed by the current Builder thread. + + Returns + ------- + res : BuildResult + The build result of this Builder thread. + """ + global GLOBAL_BUILD_ARGUMENTS + + # We use fork and a global variable to copy arguments between processings. + # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + if not GLOBAL_BUILD_ARGUMENTS: + raise ValueError("GLOBAL_BUILD_ARGUMENTS not found") + measure_inputs, build_func, timeout, verbose = GLOBAL_BUILD_ARGUMENTS + assert isinstance(build_func, str) + + if build_func == 'default': + build_func = tar.tar + elif build_func == 'ndk': + build_func = ndk.create_shared + else: + raise ValueError("Invalid build_func" + build_func) + + def timed_func(): + tic = time.time() + inp = measure_inputs[index] + task = inp.task + + error_no = MeasureErrorNo.NO_ERROR + error_msg = None + args = [] + + try: + sch, args = task.compute_dag.apply_steps_from_state( + inp.state) + # pylint: disable=broad-except + except Exception: + error_no = MeasureErrorNo.INSTANTIATION_ERROR + error_msg = make_error_msg() + + if error_no == 0: + dirname = tempfile.mkdtemp() + filename = os.path.join( + dirname, "tmp_func." + build_func.output_format) + + try: + with transform.PassContext(): # todo(lmzheng): port the unroll pass + func = build_module.build( + sch, args, target=task.target, target_host=task.target_host) + func.export_library(filename, build_func) + # pylint: disable=broad-except + except Exception: + error_no = MeasureErrorNo.COMPILE_HOST + error_msg = make_error_msg() + else: + filename = "" + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print(".", end="") + else: + print(".E", end="") # Build error + return filename, args, error_no, error_msg, time.time() - tic + + res = call_func_with_timeout(timeout, timed_func) + if isinstance(res, TimeoutError): + if verbose >= 1: + print(".T", end="") # Build timeout + res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout + + return res + + +@tvm._ffi.register_func("auto_schedule.local_builder.build") +def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbose=1): + """ + Build function of LocalBuilder to build the MeasureInputs to runnable modules. + + Parameters + ---------- + inputs : List[MeasureInput] + The MeasureInputs to be built. + timeout : int + The timeout limit (in second) for each build thread. + This is used in a wrapper of the multiprocessing.Process.join(). + n_parallel : int + Number of threads used to build in parallel. + build_func : str = 'default' + The name of build function to process the built module. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during program building. + + Returns + ------- + res : List[BuildResult] + The build results of these MeasureInputs. + """ + # We use fork and a global variable to copy arguments between processings. + # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + global GLOBAL_BUILD_ARGUMENTS + + GLOBAL_BUILD_ARGUMENTS = (inputs, build_func, timeout, verbose) + + pool = NoDaemonPool(n_parallel) + tuple_res = pool.map(local_build_worker, range(len(inputs))) + pool.terminate() + pool.join() + del pool + + results = [] + for res in tuple_res: + results.append(BuildResult(*res)) + + return results + +@tvm._ffi.register_func("auto_schedule.local_runner.run") +def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, + verbose=1): + """ + Run function of LocalRunner to test the performance of the input BuildResults. + + Parameters + ---------- + inputs : List[MeasureInput] + The MeasureInputs to be measured. + build_results : List[BuildResult] + The BuildResults to be measured. + timeout : int + The timeout limit (in second) for each run. + This is used in a wrapper of the multiprocessing.Process.join(). + number : int = 3 + The number of times to run the generated code for taking average. + We call these runs as one `repeat` of measurement. + repeat : int = 1 + The number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first "1" is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + min_repeat_ms : int = 0 + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + cooldown_interval : float = 0.0 + The cool down interval between two measurements. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during program measuring. + + Returns + ------- + res : List[MeasureResult] + The measure results of these MeasureInputs. + """ + max_float = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + + def timed_func(inp, build_res): + tic = time.time() + error_no = 0 + error_msg = None + try: + func = module.load_module(build_res.filename) + ctx = ndarray.context(str(inp.task.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + # pylint: disable=broad-except + except Exception: + costs = (max_float,) + error_no = MeasureErrorNo.COMPILE_DEVICE + error_msg = make_error_msg() + + if error_no == 0: + try: + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + build_res.args] + ctx.sync() + costs = time_f(*args).results + # pylint: disable=broad-except + except Exception: + costs = (max_float,) + error_no = MeasureErrorNo.RUNTIME_DEVICE + error_msg = make_error_msg() + + shutil.rmtree(os.path.dirname(build_res.filename)) + toc = time.time() + time.sleep(cooldown_interval) + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print("*", end="") + else: + print("*E", end="") # Run error + return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc + + measure_results = [] + assert len(inputs) == len(build_results), \ + "Measure input size should be equal to build results" + for inp, build_res in zip(inputs, build_results): + if build_res.error_no != 0: + res = (max_float,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ + time.time() + else: + res = call_func_with_timeout( + timeout, timed_func, args=(inp, build_res)) + if isinstance(res, TimeoutError): + if verbose >= 1: + print("*T", end="") # Run timeout + res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, \ + build_res.time_cost + timeout, time.time() + measure_results.append(MeasureResult(*res)) + + if verbose >= 1: + print("") + + return measure_results diff --git a/python/tvm/auto_schedule/measure_record.py b/python/tvm/auto_schedule/measure_record.py new file mode 100644 index 000000000000..25a998566280 --- /dev/null +++ b/python/tvm/auto_schedule/measure_record.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Serialization and other I/O support for measurement records (tuning logs). """ + +import numpy as np + +import tvm._ffi +from tvm.runtime import Object +from .measure import MeasureCallback, MeasureErrorNo +from . import _ffi_api + + +@tvm._ffi.register_object("auto_schedule.RecordToFile") +class RecordToFile(MeasureCallback): + """ + A measurement callback that writes measurement records into a file. + + Parameters + ---------- + filename : str + File name for this callback to write log to. + """ + def __init__(self, filename="auto_schedule_tuning.json"): + self.__init_handle_by_constructor__(_ffi_api.RecordToFile, filename) + + +@tvm._ffi.register_object("auto_schedule.RecordReader") +class RecordReader(Object): + """ + Reader of the json log file. + + Parameters + ---------- + filename : str = "auto_schedule_tuning.json" + File name for this reader to load log from. + """ + def __init__(self, filename="auto_schedule_tuning.json"): + self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename) + + def read_lines(self, max_lines=None, skip_lines=0): + """ Read multiple lines from the log file. + + Parameters + ---------- + max_lines : Optional[int] + The maximum number of lines. None to read all lines. + skip_lines : int = 0 + Skip the first n lines. + + Returns + ------- + inputs : List[MeasureInput] + The MeasureInputs loaded from the log file. + results : List[MeasureResult] + The MeasureResults loaded from the log file. + """ + inputs, results = _ffi_api.RecordReaderReadLines(self, max_lines if max_lines else -1, + skip_lines) + return inputs, results + + def __iter__(self): + while True: + ret = _ffi_api.RecordReaderReadNext(self) + if not ret: + break + yield ret[0], ret[1] # (input, result) + + +def load_records(filename): + """ + Load measurement records from a file. + + Parameters + ---------- + filename : str + File name to load log from. + + Returns + ------- + logs : List[MeasureInput, MeasureResult] + """ + return zip(*RecordReader(filename).read_lines()) + + +def save_records(filename, inputs, results): + """ + Append measure records to file. + + Parameters + ---------- + filename : str + File name to write log to. + inputs: List[MeasureInputs] + The MeasureInputs to be written. + results: List[MeasureResults] + The MeasureResults to be written. + """ + _ffi_api.SaveRecords(filename, inputs, results) + +def load_best(filename, workload_key=None, target=None): + """ Return the best measurement pair form a log file. This may return none results if + there is no legal measure pair with the specified workload_key/target found from the log file. + + Parameters + ---------- + filename : str + File name to load log from. + workload_key : Optional[str] + The workload key of the compute declaration. + With `None`, this retuns the best measure pair of all workloads. + target : Optional[tvm.target.Target] + The target device. + With `None`, this retuns the best measure pair of all target devices. + + Returns + ------- + input : MeasureInput + The best State's MeasureInput from this log fine. + result : MeasureResult + The best State's MeasureResult from this log fine. + """ + log_reader = RecordReader(filename) + best_cost = 1e30 + best_inp = None + best_res = None + + for inp, res in log_reader: + if res.error_no != MeasureErrorNo.NO_ERROR: + continue + if workload_key and inp.task.workload_key != workload_key: + continue + if target and inp.task.target.id.name != target.id.name: + continue + + costs = [v.value for v in res.costs] + cost = np.mean(costs) + if cost < best_cost: + best_cost = cost + best_inp = inp + best_res = res + + return best_inp, best_res diff --git a/python/tvm/auto_schedule/utils.py b/python/tvm/auto_schedule/utils.py new file mode 100644 index 000000000000..c29675074e22 --- /dev/null +++ b/python/tvm/auto_schedule/utils.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Common utilities for auto_schedule. """ + +from typing import Hashable +import multiprocessing +import multiprocessing.pool +import queue +import signal + +try: + import psutil +except ImportError: + raise ImportError("psutil not found, try `pip install psutil` to fix this") + +from tvm.tir import expr +from tvm.tir.transform import Simplify +from tvm.ir.transform import Sequential +from ..te import Tensor, placeholder + + +def get_func_name(func): + """Get name of a function. + + Parameters + ---------- + func: Function + The input function. + + Returns + ------- + name: str + The function name. + """ + return func.func_name if hasattr(func, 'func_name') else func.__qualname__ + + +def get_const_int(exp): + """Verifies expr is integer and get the constant value. + + Parameters + ---------- + exp : Union[tvm.tir.expr, int] + The input expression. + + Returns + ------- + out_value : int + The output. + """ + if isinstance(exp, int): + return exp + if not isinstance(exp, expr.IntImm): + opt = Sequential([Simplify()]) + exp = opt(exp) + if not isinstance(exp, expr.IntImm): + raise ValueError("Expect value to be constant int") + return exp.value + + +def get_const_tuple(in_tuple): + """Verifies input tuple is IntImm, returns tuple of int. + + Parameters + ---------- + in_tuple : Tuple[tvm.tir.expr] + The input. + + Returns + ------- + out_tuple : Tuple[int] + The output. + """ + return tuple(get_const_int(x) for x in in_tuple) + + + +def list_to_tuple(x): + """ Convert a list to a tuple recursively. """ + assert isinstance(x, list) + return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x) + + +def serialize_args(args): + """ + Serialize arguments of a function to a hashable and jsonable tuple. + Currently this is mainly used for tvm.tensor.Tensor + """ + ret = [] + for t in args: + if isinstance(t, Tensor): + t = ('TENSOR', get_const_tuple(t.shape), t.dtype) + elif isinstance(t, list): + t = list_to_tuple(t) + + assert isinstance(t, Hashable), str(t) + " is not hashable" + ret.append(t) + + return tuple(ret) + + +def deserialize_args(args): + """The inverse function of :code:`serialize_args`""" + ret = [] + for t in args: + if isinstance(t, (tuple, list)) and t[0] == 'TENSOR': + ret.append(placeholder(shape=t[1], dtype=t[2])) + else: + ret.append(t) + return ret + + +class NoDaemonProcess(multiprocessing.Process): + @property + def daemon(self): + return False + + @daemon.setter + def daemon(self, value): + pass + + +class NoDaemonContext(type(multiprocessing.get_context())): + Process = NoDaemonProcess + + +class NoDaemonPool(multiprocessing.pool.Pool): + """A no daemon pool version of multiprocessing.Pool. + This allows us to start new processings inside the worker function""" + + def __init__(self, *args, **kwargs): + kwargs['context'] = NoDaemonContext() + super().__init__(*args, **kwargs) + + def __reduce__(self): + pass + + +def kill_child_processes(parent_pid, sig=signal.SIGTERM): + """kill all child processes recursively""" + try: + parent = psutil.Process(parent_pid) + except psutil.NoSuchProcess: + return + children = parent.children(recursive=True) + for process in children: + try: + process.send_signal(sig) + except psutil.NoSuchProcess: + return + + +def call_func_with_timeout(timeout, func, args=(), kwargs=None): + """Call a function with timeout""" + def func_wrapper(que): + if kwargs: + que.put(func(*args, **kwargs)) + else: + que.put(func(*args)) + + que = multiprocessing.Queue(2) + process = multiprocessing.Process(target=func_wrapper, args=(que,)) + process.start() + process.join(timeout) + + try: + res = que.get(block=False) + except queue.Empty: + res = TimeoutError() + + # clean queue and process + kill_child_processes(process.pid) + process.terminate() + process.join() + que.close() + que.join_thread() + del process + del que + + return res diff --git a/python/tvm/auto_schedule/workload_registry.py b/python/tvm/auto_schedule/workload_registry.py new file mode 100644 index 000000000000..b50727ec955e --- /dev/null +++ b/python/tvm/auto_schedule/workload_registry.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Workload registration and serialization. + +We use a json string to represent a workload (a computation graph). +The format of the string is `[func_name, [args...]]`. +The dag should be the return value of this `func_name(*args)`. + +Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags +and matching them efficiently is not easy. Therefore, we use the above string to encode a compute +dag. +These strings are efficient for serialization/matching and won't be too long. +When we need the dag, we decode the string and call the function, which will return the dag. +""" + +import pickle +import json + +import tvm._ffi +from .utils import serialize_args, deserialize_args, get_func_name + +WORKLOAD_FUNC_REGISTRY = {} + + +def register_workload(func_name, f=None, override=False): + """ Register a function that generates a certain workload. + + The input function should take hashable and jsonable arguments + (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. + + Parameters + ---------- + func_name : Union[Function, str] + The generation function that returns the compute declaration Tensors or its function name. + f : Optional[Function] + The generation function to be registered. + override : boolean = False + Whether override existing entry. + + Examples + -------- + @auto_schedule.register_workload + def matmul(N, M, K): + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C') + return [A, B, C] + """ + global WORKLOAD_FUNC_REGISTRY + + if callable(func_name): + f = func_name + func_name = get_func_name(f) + if not isinstance(func_name, str): + raise ValueError("expect string function name") + + def register(myf): + """internal register function""" + if func_name in WORKLOAD_FUNC_REGISTRY and not override: + raise RuntimeError('%s has been registered already' % func_name) + WORKLOAD_FUNC_REGISTRY[func_name] = myf + return myf + if f: + return register(f) + return register + + +def make_workload_key(func, args): + """ Make a workload key by function and arguments. + + Parameters + ---------- + func : Union[Function, str] + The function that returns the compute declaration Tensors. + Can be the a function or the function name. + args : Args + The args of the function. + + Returns + ------- + workload_key : Str + The workload key of the function. + """ + global WORKLOAD_FUNC_REGISTRY + + if callable(func): + func_name = get_func_name(func) + elif isinstance(func, str): + func_name = func + else: + raise ValueError("Invalid function: " + str(func) + + " . `make_workload_key` expects a callable function or its function name") + + if not func_name in WORKLOAD_FUNC_REGISTRY: + raise ValueError("%s is not registered. " % func, + "Please register it with @auto_schedule.register_workload") + + args = serialize_args(args) + + return json.dumps((func_name,) + args) + + +def decode_workload_key_to_func_args(workload_key): + """ Decode a workload key to the registerd function name and its corresponding args. + + Parameters + ---------- + workload_key : str + The input workload key. + + Returns + ------- + name : str + The function name of this workload key. + args : List[Tensor] + The args of the generation function. + """ + global WORKLOAD_FUNC_REGISTRY + + workload = json.loads(workload_key) + if not workload[0] in WORKLOAD_FUNC_REGISTRY: + raise ValueError("%s is not registered. " % workload[0] + + "Please register it with @auto_schedule.register_workload") + return workload[0], deserialize_args(workload[1:]) + + +@tvm._ffi.register_func("auto_schedule.workload_key_to_tensors") +def workload_key_to_tensors(workload_key): + """ Get the input/output tensors from the workload key. + + This method is usually used to create a ComputeDAG by workload key. + + Parameters + ---------- + workload_key : str + The input workload key. + + Returns + ------- + tensors : List[Tensor] + The registered compute declaration Tensors. + """ + global WORKLOAD_FUNC_REGISTRY + + name, args = decode_workload_key_to_func_args(workload_key) + lookup = WORKLOAD_FUNC_REGISTRY[name] + assert callable(lookup) + return lookup(*args) + + +def save_workload_func_registry(filename): + """ Dump workload function registry to a pickle binary file. + + Parameters + ---------- + filename : str + The filename to dump workload function registry to. + """ + global WORKLOAD_FUNC_REGISTRY + + pickle.dump(WORKLOAD_FUNC_REGISTRY, open(filename, 'wb')) + + +def load_workload_func_registry(filename): + """ Load workload function registry from a pickle binary file. + + Parameters + ---------- + filename : str + The filename to load workload function registry from. + """ + global WORKLOAD_FUNC_REGISTRY + + WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, 'rb')) diff --git a/src/auto_schedule/auto_schedule.cc b/src/auto_schedule/auto_schedule.cc new file mode 100644 index 000000000000..aaf472b1f26a --- /dev/null +++ b/src/auto_schedule/auto_schedule.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/auto_schedule.cc + * \brief The user interface of the TVM Auto-scheduler. + */ + +#include "auto_schedule.h" + +#include + +namespace tvm { +namespace auto_schedule { + +TVM_REGISTER_NODE_TYPE(TuningOptionsNode); + +TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, + int verbose, ProgramBuilder builder, ProgramRunner runner, + Optional> measure_callbacks, + Optional> pre_search_callbacks) { + auto node = make_object(); + node->num_measure_trials = num_measure_trials; + node->early_stopping = early_stopping; + node->num_measures_per_round = num_measures_per_round; + node->verbose = verbose; + node->builder = std::move(builder); + node->runner = std::move(runner); + node->measure_callbacks = std::move(measure_callbacks); + node->pre_search_callbacks = std::move(pre_search_callbacks); + data_ = std::move(node); +} + +std::pair> AutoSchedule(SearchTask task, SearchPolicy search_policy, + TuningOptions tuning_options) { + // Create a ProgramMeasurer to handle the schedule build and performance measure + ProgramMeasurer measurer = + ProgramMeasurer(tuning_options->builder, tuning_options->runner, + tuning_options->measure_callbacks, tuning_options->verbose); + // Search for the best schedule + State state = search_policy->Search( + task, tuning_options->num_measure_trials, tuning_options->early_stopping, + tuning_options->num_measures_per_round, tuning_options->verbose, measurer, + tuning_options->pre_search_callbacks); + return task->compute_dag.ApplySteps(state->transform_steps); +} + +TVM_REGISTER_GLOBAL("auto_schedule.TuningOptions") + .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, + int verbose, ProgramBuilder builder, ProgramRunner runner, + Optional> measure_callbacks, + Optional> pre_search_callbacks) { + return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose, + builder, runner, measure_callbacks, pre_search_callbacks); + }); + +TVM_REGISTER_GLOBAL("auto_schedule.AutoSchedule") + .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuningOptions tuning_options) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tuning_options); + return Array{sch, return_tensors}; + }); +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/auto_schedule.h b/src/auto_schedule/auto_schedule.h new file mode 100644 index 000000000000..d2e49bbe7e4f --- /dev/null +++ b/src/auto_schedule/auto_schedule.h @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/auto_schedule.h + * \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get + * schedule search requirements from upper level (Python API), and returns a high performance + * schedule after search process. + */ + +#ifndef TVM_AUTO_SCHEDULE_AUTO_SCHEDULE_H_ +#define TVM_AUTO_SCHEDULE_AUTO_SCHEDULE_H_ + +#include + +#include "measure.h" +#include "search_policy/search_policy.h" + +namespace tvm { +namespace auto_schedule { + +/*! \brief Tuning and measurement options. */ +class TuningOptionsNode : public Object { + public: + /*! \brief Number of total measurement trials. */ + int num_measure_trials; + /*! \brief Stops early the tuning if no improvement after n measurements. */ + int early_stopping; + /*! \brief The number of programs to be measured at each search round. */ + int num_measures_per_round; + /*! + * \brief Verbosity level. + * 0 for silent, 1 to output information during schedule searching. + */ + int verbose; + /*! \brief ProgramBuilder which builds the program */ + ProgramBuilder builder; + /*! \brief ProgramRunner which runs the program and measure time costs */ + ProgramRunner runner; + /*! \brief MeasureCallback functions to be called after each measure batch */ + Optional> measure_callbacks; + /*! \brief SearchCallback functions to be called before schedule search */ + Optional> pre_search_callbacks; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_measure_trials", &num_measure_trials); + v->Visit("early_stopping", &early_stopping); + v->Visit("num_measures_per_round", &num_measures_per_round); + v->Visit("verbose", &verbose); + v->Visit("builder", &builder); + v->Visit("runner", &runner); + v->Visit("measure_callbacks", &measure_callbacks); + v->Visit("pre_search_callbacks", &pre_search_callbacks); + } + + static constexpr const char* _type_key = "auto_schedule.TuningOptions"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuningOptionsNode, Object); +}; + +/*! + * \brief Managed reference to TuningOptionsNode. + * \sa TuningOptionsNode + */ +class TuningOptions : public ObjectRef { + public: + /*! + * \brief The constructor + * \param num_measure_trials Number of total measurement trials. + * \param early_stopping Stops early the tuning if no improvement after n measurements. + * \param num_measures_per_round The number of programs to be measured at each search round. + * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule + * search. + * \param builder ProgramBuilder which builds the program. + * \param runner ProgramRunner which runs the program and measure time costs. + * \param measure_callbacks MeasureCallback functions to be called after each measure batch. + * \param pre_search_callbacks SearchCallback functions to be called before schedule search. + */ + TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, + ProgramBuilder builder, ProgramRunner runner, + Optional> measure_callbacks, + Optional> pre_search_callbacks); + + TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode); +}; + +/*! + * \brief Auto schedule search for a given compute declaration. + * \param task The search task of the compute declaration. + * \param search_policy The search policy to be used for schedule search. + * \param tuning_options Tuning and measurement options. + * \return A `te::schedule` and the a Array of `te::Tensor` to be used in `tvm.lower` or + * `tvm.build`. + */ +TVM_DLL std::pair> AutoSchedule(SearchTask task, + SearchPolicy search_policy, + TuningOptions tuning_options); +} // namespace auto_schedule +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULE_AUTO_SCHEDULE_H_ diff --git a/src/auto_schedule/compute_dag.cc b/src/auto_schedule/compute_dag.cc new file mode 100644 index 000000000000..312a25ad62dd --- /dev/null +++ b/src/auto_schedule/compute_dag.cc @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/compute_dag.cc + * \brief Compute declaration graph and its related analysis tools. + */ + +#include "compute_dag.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "loop_state.h" +#include "utils.h" + +namespace tvm { +namespace auto_schedule { + +using namespace tvm::tir; + +TVM_REGISTER_NODE_TYPE(ComputeDAGNode); + +// Topo-sort ops from tensors according to their read-write relations. +Array TopoSortOps(const Array& tensors) { + std::unordered_map degree; + std::unordered_map> edge_set; + std::unordered_map priority; + std::unordered_set visited; + + // traverse to build edge_set and count degree + std::vector stack; + stack.reserve(tensors.size()); + for (const auto& x : tensors) { + stack.push_back(x->op.operator->()); + } + + int ct = 0; + while (!stack.empty()) { + const te::OperationNode* op = stack.back(); + stack.pop_back(); + if (visited.count(op)) { + continue; + } + + priority[op] = ct; + ct++; + visited.insert(op); + + if (op->IsInstance()) { + degree[op] = 0; + } else if (auto cop = GetRef(op).as()) { + const Array& input_tensors = cop->InputTensors(); + degree[op] = input_tensors.size(); + for (const auto& ten : input_tensors) { + edge_set[ten->op.operator->()].push_back(op); + stack.push_back(ten->op.operator->()); + } + } else { + LOG(FATAL) << "Unsupported op " << GetRef(op); + } + } + + // topo sort + Array ops; + + using Item = std::pair; + auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; }; + std::priority_queue, decltype(cmp)> queue(cmp); + for (const auto& iter : degree) { + if (iter.second == 0) { + queue.push(Item(iter.first, priority[iter.first])); + } + } + + ops.reserve(degree.size()); + while (!queue.empty()) { + Item item = queue.top(); + queue.pop(); + ops.push_back(GetRef(item.first)); + for (const auto& dst : edge_set[item.first]) { + degree[dst] -= 1; + if (degree[dst] == 0) { + queue.push(Item(dst, priority[dst])); + } + } + } + + return ops; +} + +// Estimate number of float operations in an expression +class FlopEstimator : public ExprFunctor { + public: + double EstimateFlop(const Array& ops) { + double ret = 0; + for (const auto& op : ops) { + if (auto pop = op.as()) { + double num_element = AxisLengthProd(pop->axis); + if (num_element == -1) { + fail_ = true; + break; + } + double op_per_element = 0; + for (const auto& x : pop->body) { + op_per_element += VisitExpr(x); + } + ret += num_element * op_per_element; + } else if (op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op type " << op; + } + } + + return fail_ ? -1 : ret; + } + + double VisitExpr_(const ReduceNode* op) final { + uint64_t num_iter = 1; + for (const auto& x : op->axis) { + if (auto imm = x->dom->extent.as()) { + num_iter *= imm->value; + } else { + fail_ = true; + num_iter = -1; + } + } + double body_flop = 0; + for (size_t i = 0; i < op->combiner->result.size(); ++i) { + body_flop += VisitExpr(op->combiner->result[i]); + body_flop += VisitExpr(op->source[i]); + } + return num_iter * body_flop; + } + + double VisitExpr_(const FloatImmNode* op) final { return 0.0; } + double VisitExpr_(const IntImmNode* op) final { return 0.0; } + double VisitExpr_(const ProducerLoadNode* op) final { return 0.0; } + + double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } + double VisitExpr_(const VarNode* op) final { return 0.0; } + + double VisitExpr_(const SelectNode* op) final { + return VisitExpr(op->condition) + + std::max(VisitExpr(op->true_value), VisitExpr(op->false_value)); + } + +#define VisitBinary(Node) \ + double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); } +#define VisitUnary(Node) \ + double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a); } + + VisitBinary(AddNode); + VisitBinary(SubNode); + VisitBinary(MulNode); + VisitBinary(DivNode); + VisitBinary(ModNode); + VisitBinary(FloorDivNode); + VisitBinary(FloorModNode); + VisitBinary(MaxNode); + VisitBinary(MinNode); + VisitBinary(EQNode); + VisitBinary(NENode); + VisitBinary(LTNode); + VisitBinary(LENode); + VisitBinary(GTNode); + VisitBinary(GENode); + VisitBinary(AndNode); + VisitBinary(OrNode); + VisitUnary(NotNode); + + double VisitExpr_(const CallNode* op) final { + double ret = 0.0; + for (const auto& x : op->args) { + ret += VisitExpr(x); + } + return ret; + } + + double VisitExprDefault_(const Object* op) final { + fail_ = true; + return -1.0; + } + + private: + bool fail_{false}; +}; + +ComputeDAG::ComputeDAG(Array tensors) { + auto node = make_object(); + node->tensors = std::move(tensors); + node->ops = TopoSortOps(node->tensors); + node->flop_ct = FlopEstimator().EstimateFlop(node->ops); + node->init_state = State(node->ops); + data_ = std::move(node); +} + +// Update the te::stage to tir::IterVar axis mapping +void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) { + if (auto pop = stage->op.as()) { + Array axes; + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& axis : pop->reduce_axis) { + axes.push_back(axis); + } + stage_to_axes->Set(stage, std::move(axes)); + } else if (stage->op->IsInstance()) { + {} // do nothing on Placeholder + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + +std::pair> ComputeDAG::ApplySteps( + const Array& transform_steps, Array* stages, + StageToAxesMap* stage_to_axes) const { + // Temporal object to be used if the input pointer is nullptr + Array temp_stages; + StageToAxesMap temp_stage_to_axes; + if (stages == nullptr) { + stages = &temp_stages; + } + if (stage_to_axes == nullptr) { + stage_to_axes = &temp_stage_to_axes; + } + Array ops; + for (const auto& op : operator->()->ops) { + if (!op->IsInstance()) { + ops.push_back(op); + } + } + // Create the initial schedule + // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler, + // update this after testing with multiple outputs. + te::Schedule schedule = te::create_schedule({ops.back()}); + + // init axes + for (const auto& x : operator->()->ops) { + const te::Stage& stage = schedule[x]; + stages->push_back(stage); + UpdateStageToAxesMap(stage, stage_to_axes); + } + + // Apply the history steps to TVM schedule + for (const auto& step : transform_steps) { + // Call each step's ApplyToSchedule method + // Note: some steps have extra parameters that must be passed and they may need different + // return value, so the ApplyToSchedule is not able to be merged to single interface + if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step"; + } + } + + return std::make_pair(schedule, operator->()->tensors); +} + +String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const { + Array stages; + StageToAxesMap stage_to_axes; + Array ops; + for (const auto& op : operator->()->ops) { + if (!op->IsInstance()) { + ops.push_back(op); + } + } + // Create the initial schedule + // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler, + // update this after testing with multiple outputs. + te::Schedule schedule = te::create_schedule({ops.back()}); + + // init axes + for (const auto& x : operator->()->ops) { + const te::Stage& stage = schedule[x]; + stages.push_back(stage); + UpdateStageToAxesMap(stage, &stage_to_axes); + } + + std::stringstream ss; + for (const auto& stage : stages) { + if (stage->op->IsInstance()) { + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + ss << stage->leaf_iter_vars[i]->var->name_hint; + if (i != stage->leaf_iter_vars.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "tuple(" << stage->op->name << ".op.axis)" + << " + " + << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; + } + } + // Call each step's PrintAsPythonAPI method + for (const auto& step : transform_steps) { + if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step"; + } + } + + return ss.str(); +} + +State ComputeDAG::InferBound(const State& state) const { + CHECK(state->concrete) << "Only concrete state can be processed to get bound info."; + + State ret_state; + StateNode* pstate; + + if (state->stages.empty()) { + // If the input state is incomplete with empty operation stage + // create a new state from init_state and update it first + ret_state = operator->()->init_state; + pstate = ret_state.CopyOnWrite(); + pstate->transform_steps = state->transform_steps; + ret_state.DoSteps(*this); + } else { + ret_state = state; + pstate = ret_state.CopyOnWrite(); + } + + Array stages; + StageToAxesMap stage_to_axes; + te::Schedule sch; + Array tensors; + // Replay steps to tvm::Schedule + std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes); + sch = sch.normalize(); + // Get bound information from TVM schedule + Map bounds = te::InferBound(sch); + + // Update the state bound information + for (size_t i = 0; i < pstate->stages.size(); ++i) { + const Stage& stage = pstate->stages[i]; + + if (stage->compute_at == ComputeAtKind::kInlined) { + continue; + } + + Array new_iters; + new_iters.reserve(stage->iters.size()); + // Get bound information from schedule + // the StageToAxesMap is used to find the corresponding IterVar in TVM schedule result + for (size_t j = 0; j < stage->iters.size(); ++j) { + const Iterator& iter = stage->iters[j]; + const IterVar& axis = stage_to_axes.at(stages[i])[j]; + + auto find_res = bounds.find(axis); + if (find_res != bounds.end()) { + new_iters.push_back( + Iterator(iter->name, (*find_res).second, iter->iter_kind, iter->annotation)); + } else { + LOG(FATAL) << "Infer bound fails"; + } + } + + pstate->stages.Set( + i, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); + } + + return ret_state; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + std::stringstream ss; + + for (const auto& op : node->ops) { + if (op->IsInstance()) { + ss << op->name << " = PLACEHOLDER " << op.output(0)->shape << "\n"; + } else if (auto pop = op.as()) { + for (size_t k = 0; k < pop->body.size(); ++k) { + ss << op->name << "("; + for (size_t i = 0; i < pop->axis.size(); i++) { + ss << pop->axis[i]->var->name_hint; + if (i != pop->axis.size() - 1) { + ss << ", "; + } + } + ss << ")"; + if (pop->body.size() > 1) { + ss << ".v" << k; + } + if (auto preduce = pop->body[k].as()) { + CHECK_LT(k, preduce->combiner->result.size()); + PrimExpr combiner = preduce->combiner->result[k]; + if (combiner->IsInstance()) { + ss << " += " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + ss << " max= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + ss << " min= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + const auto& select = combiner.as(); + ss << " select(" << select->condition << ", " << select->true_value << ", " + << select->false_value << ")= " << '(' << preduce->source[0] << ',' + << preduce->source[1] << ")\n"; + } else { + LOG(FATAL) << "Unsupported reduction operator" << combiner; + } + } else { + ss << " = " << pop->body[k] << "\n"; + } + } + } else { + LOG(FATAL) << "Invalid op"; + } + } + + p->stream << ss.str(); + }); + +TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAG").set_body_typed([](Array tensors) { + return ComputeDAG(tensors); +}); + +TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAGApplyStepsFromState") + .set_body_typed([](const ComputeDAG& dag, const State& state) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); + return Array{sch, return_tensors}; + }); + +TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAGPrintPythonCodeFromState") + .set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.PrintStepsAsPython(state->transform_steps); + }); + +TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAGInferBoundFromState") + .set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.InferBound(state); + }); + +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/compute_dag.h b/src/auto_schedule/compute_dag.h new file mode 100644 index 000000000000..bb582d32ee7e --- /dev/null +++ b/src/auto_schedule/compute_dag.h @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/compute_dag.h + * \brief The TVM Auto-scheduler computational graph and related program analyses. + * + * We convert a compute declaration described by `tvm.compute` (could be a single operator or a + * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, + * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the + * total float operation count, consumer/producer relations of each operation stage, whether an + * operation stage should be tiled/compute inlined ...). These analyses can help the search policy + * to make decisions during search process. + * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and + * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing + * `LoopState` with extra information got from TVM schedule ...). + */ + +#ifndef TVM_AUTO_SCHEDULE_COMPUTE_DAG_H_ +#define TVM_AUTO_SCHEDULE_COMPUTE_DAG_H_ + +#include + +#include + +#include "loop_state.h" + +namespace tvm { +namespace auto_schedule { + +/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */ +class ComputeDAGNode : public Object { + public: + /*! + * \brief Input and output tensors. + * This is used as the input of `tvm.lower` or `tvm.build`. + */ + Array tensors; + /*! \brief All related operations in topo order. */ + Array ops; + /*! \brief Number of total float operations for this ComputeDAG. */ + double flop_ct; + /*! \brief The initial state without any transform steps. */ + State init_state; + // TODO(merrymercy): Add more analyses later. + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tensors", &tensors); + v->Visit("ops", &ops); + v->Visit("flop_ct", &flop_ct); + v->Visit("init_state", &init_state); + } + + static constexpr const char* _type_key = "auto_schedule.ComputeDAG"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); +}; + +/*! + * \brief Managed reference to ComputeDAGNode. + * \sa ComputeDAGNode + */ +class ComputeDAG : public ObjectRef { + public: + /*! \brief The constructor. + * \param tensors `te::Tensor`s for a compute declaration. + */ + explicit ComputeDAG(Array tensors); + + /*! + * \brief Apply the history transform steps from a State to get a TVM schedule. + * \param transform_steps Transform steps of a state. + * \param stages A pointer to a `te::Stage` Array, default to be nullptr. + * Pass a valid pointer if these information needs to be used outside this function. + * \param stage_to_axes A pointer to a StageToAxesMap, default to be nullptr. + * Pass a valid pointer if these information needs to be used outside this function. + * \return A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`. + */ + std::pair> ApplySteps( + const Array& transform_steps, Array* stages = nullptr, + StageToAxesMap* stage_to_axes = nullptr) const; + + /*! + * \brief Print transform steps as equivalent python schedule API. + * This can be used for debugging. + * \param transform_steps Transform steps of a state. + * \return The Python schedule code. + */ + String PrintStepsAsPython(const Array& transform_steps) const; + + /*! + * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound. + * The states can lose complete bound information after some transform steps (e.g., compute_at). + * We can call this function to infer and fill all the bound information. + * This function calls TVM InferBound pass internally to get the bound. + * The returned state of this function is guaranteed to have complete iterator extent information. + * \param state The state to. + * \return The State after inferbound. + */ + State InferBound(const State& state) const; + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); +}; + +} // namespace auto_schedule +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULE_COMPUTE_DAG_H_ diff --git a/src/auto_schedule/loop_state.cc b/src/auto_schedule/loop_state.cc new file mode 100644 index 000000000000..666efcebce04 --- /dev/null +++ b/src/auto_schedule/loop_state.cc @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/loop_state.cc + * \brief An lightweight IR (intermediate representation) for loop structures. + * see auto_schedule/loop_state.h for more explanation. + */ + +#include "loop_state.h" + +#include +#include + +#include + +#include "transform_step.h" +#include "utils.h" + +namespace tvm { +namespace auto_schedule { + +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StageNode); +TVM_REGISTER_NODE_TYPE(StateNode); +TVM_REGISTER_NODE_TYPE(IteratorNode); + +/********** Iterator **********/ +Iterator::Iterator(String name, Range range, IteratorKind iter_kind, + IteratorAnnotation annotation) { + auto node = make_object(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_kind = iter_kind; + node->annotation = annotation; + data_ = std::move(node); +} + +/********** Stage **********/ +Stage::Stage(te::Operation op) { + auto node = make_object(); + if (op->IsInstance()) { + node->op_type = StageKind::kCompute; + auto* pop = op.as(); + for (const auto& axis : pop->axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, + IteratorKind::kSpatial, IteratorAnnotation::kNone)); + } + for (const auto& axis : pop->reduce_axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, + IteratorKind::kReduction, IteratorAnnotation::kNone)); + } + } else if (op->IsInstance()) { + node->op_type = StageKind::kPlaceholder; + } else { + LOG(FATAL) << "Unsupported operator type" << op->_type_key; + } + + node->compute_at = ComputeAtKind::kRoot; + node->op = std::move(op); + node->attrs.auto_unroll_max_step = 0; + node->attrs.storage_offset = 0; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageKind op_type, const Array& iters, + ComputeAtKind compute_at, StageAttributes attrs) { + auto node = make_object(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = iters; + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +/********** State **********/ +State::State(const Array& ops) { + auto node = make_object(); + for (const auto& op : ops) { + node->stages.push_back(Stage(op)); + } + node->concrete = true; + data_ = std::move(node); +} + +/********** Schedule primitives apis for state **********/ +void State::reorder(int stage_id, const Array& order) { + const Stage& stage = operator->()->stages[stage_id]; + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + << "should be specified"; + Array after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); + CopyOnWrite()->transform_steps.push_back(step); + DoReorderStep(step); +} + +Array State::split(int stage_id, const Iterator& it, + const Array>& lengths, bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + SplitStep step = + SplitStep(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer); + CopyOnWrite()->transform_steps.push_back(step); + return DoSplitStep(step); +} + +Iterator State::fuse(int stage_id, const Array& iters) { + const Stage& stage = operator->()->stages[stage_id]; + Array indices; + GetIndices(stage->iters, iters, &indices); + FuseStep step = FuseStep(stage_id, indices); + CopyOnWrite()->transform_steps.push_back(step); + return DoFuseStep(step); +} + +/********** Step implementations for state **********/ +void State::DoReorderStep(const ReorderStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + Array iters; + for (auto x : step->after_ids) { + iters.push_back(stage->iters[x]); + } + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(step->stage_id, + Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); +} + +// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep +Array State::DoSplitStepCommon(int stage_id, int iter_id, + const Array>& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + bool concrete = true; + + Optional tosplit_min, tosplit_extent; + if (it->range.defined()) { + tosplit_min = it->range->min; + tosplit_extent = it->range->extent; + } else { + tosplit_min = NullOpt; + tosplit_extent = NullOpt; + } + + Array outs; + for (size_t i = 0; i < lengths.size(); ++i) { + Optional l; + String name; + if (inner_to_outer) { + l = lengths[lengths.size() - i - 1]; + name = it->name + "." + std::to_string(lengths.size() - i); + } else { + l = lengths[i]; + name = it->name + "." + std::to_string(i); + } + Iterator res; + if (l && tosplit_min && tosplit_extent) { + res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind, + IteratorAnnotation::kNone); + tosplit_min = Integer(0); + tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value()); + } else { + res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone); + tosplit_min = NullOpt; + tosplit_extent = NullOpt; + concrete = false; + } + outs.push_back(std::move(res)); + } + + Range range; + if (tosplit_min && tosplit_extent) { + range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value()); + } + if (inner_to_outer) { + outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone)); + // Reverse the Iterator array + Array temp(outs.rbegin(), outs.rend()); + outs = std::move(temp); + } else { + outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind, + IteratorAnnotation::kNone)); + } + + Array new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); + new_iters.insert(new_iters.end(), outs.begin(), outs.end()); + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); + + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(stage_id, + Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); + pstate->concrete &= concrete; + + return outs; +} + +Array State::DoSplitStep(const SplitStep& step) { + return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer); +} + +Iterator State::DoFuseStep(const FuseStep& step) { + int stage_id = step->stage_id; + const Stage& stage = operator->()->stages[stage_id]; + + String new_name; + PrimExpr new_extent = 1; + IteratorKind new_iter_kind = IteratorKind::kSpecial; + + for (size_t i = 0; i < step->fused_ids.size(); ++i) { + if (i > 0) { + CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1); + } + + const Iterator& it = stage->iters[step->fused_ids[i]]; + new_name = new_name + it->name + "@"; + + if (it->range.defined() && new_extent.defined()) { + new_extent = new_extent * it->range->extent; + } else { + new_extent = PrimExpr(); + } + + if (i == 0) { + new_iter_kind = it->iter_kind; + } else { + if (new_iter_kind != it->iter_kind) { + new_iter_kind = IteratorKind::kMixed; + } + } + } + + Range range; + if (new_extent.defined()) { + range = Range::FromMinExtent(0, new_extent); + } + Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone); + Array new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), + stage->iters.begin() + step->fused_ids.front()); + new_iters.push_back(new_it); + new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1, + stage->iters.end()); + + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(stage_id, + Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); + + return new_it; +} + +void State::DoSteps(const ComputeDAG& dag) { + CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; + + for (const auto& step : operator->()->transform_steps) { + if (auto ps = step.as()) { + DoReorderStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoSplitStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoFuseStep(GetRef(ps)); + } else { + LOG(FATAL) << "Invalid step: " << step; + } + } +} + +static const char* IteratorAnnotationString[] = { + "for", // kNone = 0 + "unroll", // kUnroll = 1 + "vectorize", // kVectorize = 2 + "parallel", // kParallel = 3 + "vthread", // kVThread = 4 + "gpu.blockIdx.x", // kBlockX = 5 + "gpu.threadIdx.x", // kThreadX = 6 + "gpu.blockIdx.y", // kBlockY = 7 + "gpu.threadIdx.y", // kThreadY = 8 + "tensorize" // kTensorized = 9 +}; + +// Print stage to ostream +void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent, + bool delete_trivial_loop) { + const Stage& stage = state->stages[stage_id]; + + if (stage->attrs.auto_unroll_max_step != 0) { + for (size_t j = 0; j < base_indent; ++j) { + *os << " "; + } + *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n"; + } + if (stage->attrs.storage_offset != 0) { + for (size_t j = 0; j < base_indent; ++j) { + *os << " "; + } + *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n"; + } + + size_t indent = 0; + for (size_t i = 0; i < stage->iters.size(); ++i) { + const Iterator& iter = stage->iters[i]; + + if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) { + for (size_t j = 0; j < base_indent + indent; ++j) { + *os << " "; + } + *os << IteratorAnnotationString[static_cast(iter->annotation)] << " "; + if (iter->range.defined()) { + *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")"; + } else { + *os << iter->name << " (None)"; + } + *os << "\n"; + + indent += 2; + } + } + + for (size_t j = 0; j < base_indent + indent; ++j) { + *os << " "; + } + *os << stage->op->name << " = ...\n"; +} + +// Print state to ostream +void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) { + // Gather placeholders + Array placeholders; + for (const auto& stage : state->stages) { + if (stage->op_type == StageKind::kPlaceholder) { + placeholders.push_back(stage->op->name); + } + } + + *os << "Placeholder: "; + for (size_t i = 0; i < placeholders.size(); ++i) { + *os << placeholders[i]; + if (i != placeholders.size() - 1) { + *os << ", "; + } + } + *os << "\n"; + + // Print all stages + for (size_t i = 0; i < state->stages.size(); ++i) { + const Stage& stage = state->stages[i]; + if (stage->op_type == StageKind::kPlaceholder) { + continue; + } else if (stage->op_type == StageKind::kCompute) { + if (stage->compute_at == ComputeAtKind::kRoot) { + PrintStage(os, i, state, 0, delete_trivial_loop); + } + } else { + LOG(FATAL) << "Invalid op type"; + } + } +} + +String State::ToStr(bool delete_trivial_loop) const { + std::ostringstream os; + PrintState(&os, (*this), delete_trivial_loop); + return os.str(); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + PrintState(&p->stream, tvm::Downcast(ref), true); + }); + +/********** State interface API for ffi **********/ +TVM_REGISTER_GLOBAL("auto_schedule.StateReorder") + .set_body_typed([](State state, int stage_id, const Array& order) { + state.reorder(stage_id, order); + return state; + }); + +TVM_REGISTER_GLOBAL("auto_schedule.StateSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array>& lengths, bool inner_to_outer) { + const auto& res = state.split(stage_id, it, lengths, inner_to_outer); + return Array{state, res}; + }); + +TVM_REGISTER_GLOBAL("auto_schedule.StateFuse") + .set_body_typed([](State state, int stage_id, const Array& iters) { + const auto& res = state.fuse(stage_id, iters); + return Array{state, res}; + }); + +TVM_REGISTER_GLOBAL("auto_schedule.StateEqual").set_body_typed([](State state1, State state2) { + return std::equal_to()(state1, state2); +}); + +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/loop_state.h b/src/auto_schedule/loop_state.h new file mode 100644 index 000000000000..5ba47b7263a1 --- /dev/null +++ b/src/auto_schedule/loop_state.h @@ -0,0 +1,382 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/loop_state.h + * \brief The definition of the "state" in search. + * + * Each LoopState corresponds to a schedule for its ComputeDAG. + * A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to + * construct the loop structure. + * The loop structure keeps a preview of how the schedule will finally look like after lowering the + * current state (e.g. number of iterators, the extent of each iterator, the compute_at locations + * ...). + * During the schedule search process, the loop structure can provide search policy with necessary + * information on how to manipulate the current state. + * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM + * schedule primitives. The steps can also be used for the serialization of a state. + * + * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. + * We don't use the existing TVM IR but to extend a new structure on it is because: + * 1. We want fast incremental change to the loop structures. The search policy needs to get the + * immediate loop structures update rather than after TVM lowering; + * 2. We want serializable transform history for replay, backtracking, and mutation; + * 3. We may create some macro schedule primitives that represent the combination of several + * TVM schedule primitives. + * + * When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives. + * Since we share a lot of common objects during search, the transformation is implemented in + * copy on write style. All objects are immutable, which is similar to TVM IR. + */ + +#ifndef TVM_AUTO_SCHEDULE_LOOP_STATE_H_ +#define TVM_AUTO_SCHEDULE_LOOP_STATE_H_ + +#include + +#include + +#include "transform_step.h" + +namespace tvm { +namespace auto_schedule { + +using namespace tvm::tir; + +class ComputeDAG; + +/*! \brief The type of a stage. */ +enum class StageKind : int { + /*! \brief A placeholder stage. */ + kPlaceholder = 0, + /*! \brief A compute stage. */ + kCompute = 1 +}; + +/*! \brief The type of compute location. */ +enum class ComputeAtKind : int { + /*! \brief Compute at root. */ + kRoot = 0, + /*! \brief Compute inlined. */ + kInlined = 1, + /*! \brief Compute at some iterator. */ + kIter = 2, +}; + +/*! \brief The type of an iterator. */ +enum class IteratorKind : int { + /*! \brief Spatial iterator. */ + kSpatial = 0, + /*! \brief Reduction iterator. */ + kReduction = 1, + /*! \brief Fused spatial and reduction iterator. */ + kMixed = 2, + /*! \brief Special iterator. (e.g. virtual root iterator) */ + kSpecial = 3 +}; + +/*! \brief The type of an iterator's annotation. */ +enum class IteratorAnnotation : int { + /*! \brief This iterator has no annotation. */ + kNone = 0, + /*! \brief This iterator has been unrolled. */ + kUnroll = 1, + /*! \brief This iterator has been vectorized. */ + kVectorize = 2, + /*! \brief This iterator has been paralleld. */ + kParallel = 3, + /*! \brief This iterator has been bind to vthread. */ + kVThread = 4, + /*! \brief This iterator has been bind to blockIdx.x. */ + kBlockX = 5, + /*! \brief This iterator has been bind to threadIdx.x. */ + kThreadX = 6, + /*! \brief This iterator has been bind to blockIdx.y. */ + kBlockY = 7, + /*! \brief This iterator has been bind to threadIdx.y. */ + kThreadY = 8, + /*! \brief This iterator has been mapped with a tensorize intrinsic. */ + kTensorized = 9 +}; + +/*! + * \brief A for loop iterator + * Similar to tvm::IterVar in `include/tvm/tir/expr.h` + */ +class IteratorNode : public Object { + public: + /*! \brief The name of this iterator. */ + String name; + /*! \brief The range of this iterator. */ + Range range; + /*! \brief The iterator type of this iterator. */ + IteratorKind iter_kind; + /*! \brief The annotation type of this iterator. */ + IteratorAnnotation annotation; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("range", &range); + } + + static constexpr const char* _type_key = "auto_schedule.Iterator"; + TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); +}; + +/*! + * \brief Managed reference to IteratorNode. + * \sa IteratorNode + */ +class Iterator : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param name The name of this iterator. + * \param range The range of this iterator. + * \param iter_kind The iterator type of this iterator. + * \param annotation The annotation type of this iterator. + */ + Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation); + + TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); +}; + +/*! \brief Stage-level attributes. */ +struct StageAttributes { + /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */ + int auto_unroll_max_step; + /*! \brief The storage offset for the schedule primitive `storage_align`. */ + int storage_offset; +}; + +/*! + * \brief A op stage in the compute declaration. + * Similar to te::Stage in `include/schedule.h`. + */ +class StageNode : public Object { + public: + /*! \brief The operator of this stage */ + te::Operation op; + /*! \brief The type of this stage. */ + StageKind op_type; + /*! \brief The iterators in this stage. */ + Array iters; + /*! \brief The compute location of this stage. */ + ComputeAtKind compute_at; + /*! \brief Other stage-level attributes. */ + StageAttributes attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("iters", &iters); + } + + static constexpr const char* _type_key = "auto_schedule.Stage"; + TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); +}; + +/*! + * \brief Managed reference to StageNode. + * \sa StageNode + */ +class Stage : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param op A `te::Operation`. + */ + explicit Stage(te::Operation op); + /*! + * \brief The constructor. + * \param op A `te::Operation`. + * \param op_type The stage type of this op. + * \param iters The iterators of this op. + * \param compute_at The compute at type of this op. + * \param attrs Other stage-level attributes. + */ + Stage(te::Operation op, StageKind op_type, const Array& iters, ComputeAtKind compute_at, + StageAttributes attrs); + + TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); +}; + +/*! + * \brief A state in the search process. + * It consists of the current loop structure and a list of transformation steps used to construct + * it. + * Each State corresponds to a specific schedule for its ComputeDAG. + */ +class StateNode : public Object { + public: + /*! \brief Current stages and loop structures. */ + Array stages; + /*! \brief History transformation steps. */ + Array transform_steps; + /*! + * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all + * tile sizes of the state is filled. Only concrete state can be apply to TVM schedule. + */ + bool concrete; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("stages", &stages); + v->Visit("transform_steps", &transform_steps); + v->Visit("concrete", &concrete); + } + + static constexpr const char* _type_key = "auto_schedule.State"; + TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); + + private: + /*! + * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the + * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added + * later). + * The default value is an empty ObjectRef. (means no modification to the original DAG) + */ + ObjectRef current_compute_dag; +}; + +/*! + * \brief Managed reference to StateNode. + * \sa StateNode + */ +class State : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param ops `te::Operation`s for a compute declaration. + */ + explicit State(const Array& ops); + + /*! + * \brief Print the state to a human readable string. + * \param delete_trivial_loop True for skipping the trivial loops. + * (undefined or extent == 1, default set to True) + * \return The human readable state structure. + */ + String ToStr(bool delete_trivial_loop = true) const; + + /*! + * \brief General do step functions with a runtime dynamic dispatcher. This will re-apply all the + * transform steps with the initial state. + * \param dag The original ComputeDAG of this state. + * \note This is different from the class member `current_compute_dag`, for some transform step + * may change the op stage structure of the ComputeDAG. + */ + void DoSteps(const ComputeDAG& dag); + + /* Step APIs for State. */ + + /*! + * \brief Schedule primitive corresponds to te.reorder. + * \param stage_id The index of the stage to be reordered. + * \param order The expected iterator order. + */ + void reorder(int stage_id, const Array& order); + /*! + * \brief Schedule primitive corresponds to te.split. + * \param stage_id The index of the stage to be split. + * \param it The iterator the be split. + * \param lengths The multiple split factors. Can be None to be filled by search policy. + * \param inner_to_outer Whether the factor go from inner to outer, or from outer to inner. + * \return The iterator results after split. + */ + Array split(int stage_id, const Iterator& it, const Array>& lengths, + bool inner_to_outer = true); + /*! + * \brief Schedule primitive corresponds to te.fuse. + * \param stage_id The index of the stage to be fused. + * \param iters The iterators to be fused. + * \return The iterator result after fuse. + */ + Iterator fuse(int stage_id, const Array& iters); + + TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); + + private: + /* Do transform steps + * Note: The following functions only change loop state but do not change transform_history. + * We separate these functions out, so you can call them for replay easily given history steps */ + + /*! + * \brief Apply reorder step to current state. + * \param step A ReorderStep. + */ + void DoReorderStep(const ReorderStep& step); + /*! + * \brief Apply split step to current state. + * \param step A SplitStep. + * \return The iterator results after split. + */ + Array DoSplitStep(const SplitStep& step); + /*! + * \brief Apply fuse step to current state. + * \param step A FuseStep. + * \return The iterator result after fuse. + */ + Iterator DoFuseStep(const FuseStep& step); + + /*! + * \brief Common function for DoSplitStep and DoFollowSplitStep(Will be added later). + * \param stage_id The index of the stage to be split. + * \param iter_id The index of the iterator to be split. + * \param lengths The multiple split factors. + * \param inner_to_outer The split direction. + * \return The iterator results after split. + */ + Array DoSplitStepCommon(int stage_id, int iter_id, + const Array>& lengths, bool inner_to_outer); +}; + +} // namespace auto_schedule +} // namespace tvm + +// Hash and equal function for State +namespace std { + +/*! \brief The hash function for auto_schedule::State. */ +template <> +struct hash<::tvm::auto_schedule::State> { + std::size_t operator()(const ::tvm::auto_schedule::State& state) const { + return tvm::runtime::ObjectHash()(state.ToStr()); + } +}; + +/*! + * \brief The equal_to function for auto_schedule::State. + * We use the schedule result(its string format) of a state to check if two states are `euqal`. + * Equal States: 1. the transform steps are totally the same; 2. even with different steps, two + * states may still result in a same schedule. e.g. To split a axis with extent 512 to 3 parts + * [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can get a same result + * to split from outter to inner by factors [8, 16]) + */ +template <> +struct equal_to<::tvm::auto_schedule::State> { + bool operator()(const ::tvm::auto_schedule::State& lhs, + const ::tvm::auto_schedule::State& rhs) const { + return lhs.ToStr() == rhs.ToStr(); + } +}; + +} // namespace std + +#endif // TVM_AUTO_SCHEDULE_LOOP_STATE_H_ diff --git a/src/auto_schedule/measure.cc b/src/auto_schedule/measure.cc new file mode 100644 index 000000000000..b710745b02f9 --- /dev/null +++ b/src/auto_schedule/measure.cc @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/measure.cc + * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. + */ + +#include "measure.h" + +#include + +#include + +#include "utils.h" + +namespace tvm { +namespace auto_schedule { + +TVM_REGISTER_NODE_TYPE(MeasureInputNode); +TVM_REGISTER_NODE_TYPE(BuildResultNode); +TVM_REGISTER_NODE_TYPE(MeasureResultNode); +TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode); +TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode); +TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode); +TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); + +static const char* ErrorNoToStr[] = { + "NoError", + "InstantiationError", + "CompileHostError", + "CompileDeviceError", + "RuntimeDeviceError", + "WrongAnswerError", + "BuildTimeoutError", + "RunTimeoutError", + "UnknownError", +}; + +/********** Measure input and result **********/ +MeasureInput::MeasureInput(SearchTask task, State state) { + auto node = make_object(); + node->task = std::move(task); + node->state = std::move(state); + data_ = std::move(node); +} + +MeasureInput MeasureInputNode::copy() const { + auto node = make_object(); + node->task = task; + node->state = state; + return MeasureInput(node); +} + +BuildResult::BuildResult(String filename, Array args, int error_no, String error_msg, + double time_cost) { + auto node = make_object(); + node->filename = std::move(filename); + node->args = std::move(args); + node->error_no = error_no; + node->error_msg = std::move(error_msg); + node->time_cost = time_cost; + data_ = std::move(node); +} + +MeasureResult::MeasureResult(Array costs, int error_no, String error_msg, double all_cost, + double timestamp) { + auto node = make_object(); + node->costs = std::move(costs); + node->error_no = error_no; + node->error_msg = std::move(error_msg); + node->all_cost = all_cost; + node->timestamp = timestamp; + data_ = std::move(node); +} + +MeasureResult MeasureResultNode::copy() const { + auto node = make_object(); + node->costs = costs; + node->error_no = error_no; + node->error_msg = error_msg; + node->all_cost = all_cost; + node->timestamp = timestamp; + return MeasureResult(node); +} + +/********** LocalBuilder **********/ +LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& build_func) { + auto node = make_object(); + node->timeout = timeout; + node->n_parallel = n_parallel; + node->build_func = build_func; + data_ = std::move(node); +} + +Array LocalBuilderNode::Build(const Array& inputs, int verbose) { + if (const auto* f = runtime::Registry::Get("auto_schedule.local_builder.build")) { + Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); + return results; + } + LOG(FATAL) << "auto_schedule.local_builder.build is not registered. " + << "This is a function registered in Python, " + << "make sure the TVM Python runtime has been loaded successfully."; + throw; +} + +/********** LocalRunner **********/ +LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + ObjectPtr node = make_object(); + node->timeout = timeout; + node->number = number; + node->repeat = repeat; + node->min_repeat_ms = min_repeat_ms; + node->cooldown_interval = cooldown_interval; + data_ = std::move(node); +} + +Array LocalRunnerNode::Run(const Array& inputs, + const Array& build_results, int verbose) { + if (const auto* f = runtime::Registry::Get("auto_schedule.local_runner.run")) { + Array results = (*f)(inputs, build_results, timeout, number, repeat, + min_repeat_ms, cooldown_interval, verbose); + return results; + } + LOG(FATAL) << "auto_schedule.local_runner.run is not registered. " + << "This is a function registered in Python, " + << "make sure the TVM Python runtime has been loaded successfully."; + throw; +} + +/********** ProgramMeasurer **********/ +ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, + Optional> callbacks, int verbose, + int max_continous_error) { + auto node = make_object(); + node->builder = std::move(builder); + node->runner = std::move(runner); + node->callbacks = std::move(callbacks); + node->verbose = verbose; + node->max_continous_error = max_continous_error < 0 + ? ProgramMeasurerNode::DEFAULT_MAX_CONTINOUS_ERROR + : max_continous_error; + data_ = std::move(node); +} + +void ProgramMeasurerNode::Reset() { + ct = error_ct = 0; + best_flops.clear(); + best_ct.clear(); + best_state.clear(); +} + +void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& policy, + const Array& inputs, Array* results, + int batch_size) { + results->clear(); + results->reserve(inputs.size()); + + if (batch_size == -1) { + // set default batch size + batch_size = builder->n_parallel * 2; + } + + StdCout(verbose) << "Get " << inputs.size() << " programs for measure. (This may take a while)" + << std::endl; + + for (size_t i = 0; i < inputs.size(); i += batch_size) { + Array input_batch(inputs.begin() + i, + inputs.begin() + std::min(i + batch_size, inputs.size())); + Array result_batch; + + // build and run + SilentMeasure(task, input_batch, &result_batch); + + // update current best state according to the new measure result + for (size_t j = 0; j < input_batch.size(); ++j) { + double flops; + if (result_batch[j]->error_no == 0) { + flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); + error_ct = 0; + } else { + flops = 0.0; + error_ct++; + } + + const String& workload_key = input_batch[j]->task->workload_key; + if (flops > best_flops[workload_key]) { + best_flops[workload_key] = flops; + best_state[workload_key] = input_batch[j]->state; + best_ct[workload_key] = ct; + } + + ct++; + StdCout(verbose) << std::fixed << std::setprecision(2) << Chars('=', 50) << "\n" + << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " + << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] << "\n" + << Chars('=', 50) << "\n" + << input_batch[j]->state << "\n"; + } + + // Call callback functions + if (callbacks) { + for (const auto& callback : callbacks.value()) { + callback->Callback(policy, input_batch, result_batch); + } + } + + // Store result batch + for (auto& res : result_batch) { + results->push_back(res); + } + + if (error_ct > max_continous_error) { + LOG(FATAL) << "Too many errors happened during tuning"; + } + } +} + +void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array& inputs, + Array* results) { + results->clear(); + results->reserve(inputs.size()); + + // Call builder and runner + Array build_res_batch = builder->Build(inputs, verbose); + Array result_batch = runner->Run(inputs, build_res_batch, verbose); + + // Store result batch + for (auto& res : result_batch) { + results->push_back(res); + } +} + +/********** Printing functions **********/ +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "MeasureInput()"; + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + if (node->error_no == static_cast(MeasureErrorNO::kNoError)) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream << pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; + } + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no << ", " + << node->time_cost << ")"; + }); + +/********** Measure interface API for ffi **********/ +TVM_REGISTER_GLOBAL("auto_schedule.MeasureInput").set_body_typed([](SearchTask task, State state) { + return MeasureInput(task, state); +}); + +TVM_REGISTER_GLOBAL("auto_schedule.BuildResult") + .set_body_typed([](String filename, Array args, int error_no, String error_msg, + double time_cost) { + return BuildResult(filename, args, error_no, error_msg, time_cost); + }); + +TVM_REGISTER_GLOBAL("auto_schedule.MeasureResult") + .set_body_typed([](Array costs, int error_no, String error_msg, double all_cost, + double timestamp) { + return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); + }); + +TVM_REGISTER_GLOBAL("auto_schedule.ProgramBuilderBuild") + .set_body_typed([](const ProgramBuilder& builder, const Array& inputs, + int verbose) { return builder->Build(inputs, verbose); }); + +TVM_REGISTER_GLOBAL("auto_schedule.ProgramRunnerRun") + .set_body_typed([](const ProgramRunner& runner, const Array& inputs, + const Array& build_results, + int verbose) { return runner->Run(inputs, build_results, verbose); }); + +TVM_REGISTER_GLOBAL("auto_schedule.LocalBuilder") + .set_body_typed([](int timeout, int n_parallel, const String& build_func) { + return LocalBuilder(timeout, n_parallel, build_func); + }); + +TVM_REGISTER_GLOBAL("auto_schedule.LocalRunner") + .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); + }); + +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/measure.h b/src/auto_schedule/measure.h new file mode 100644 index 000000000000..a7890eaffd0d --- /dev/null +++ b/src/auto_schedule/measure.h @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/measure.h + * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. + * These functions are responsible for building the tvm module, uploading it to remote devices, + * recording the running time costs, and checking the correctness of the output. + * + * We separate the measurement into two steps: build and run. + * A builder builds the executable binary files and a runner runs the binary files to get the + * measurement results. The flow of data structures is + * + * `ProgramBuilder` `ProgramRunner` + * `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult` + * + * We implement these in python to utilize python's multiprocessing and error handling. + */ + +#ifndef TVM_AUTO_SCHEDULE_MEASURE_H_ +#define TVM_AUTO_SCHEDULE_MEASURE_H_ + +#include +#include +#include + +#include "loop_state.h" +#include "search_task.h" + +namespace tvm { +namespace auto_schedule { + +class SearchPolicy; +class MeasureInput; +class MeasureResult; + +/*! \brief The error code of one measurement */ +enum class MeasureErrorNO : int { + /*! \brief No error. */ + kNoError = 0, + /*! \brief Errors happen when apply transform steps from init state. */ + kInstantiationError = 1, + /*! \brief Errors happen when compiling code on host. (when build module) */ + kCompileHostError = 2, + /*! \brief Errors happen when compiling code on device. (when load module) */ + kCompileDeviceError = 3, + /*! \brief Errors happen when run program on device. */ + kRuntimeDeviceError = 4, + /*! \brief Answer is wrong when compared to a reference output. */ + kWrongAnswerError = 5, + /*! \brief Timeout during compilation. */ + kBuildTimeoutError = 6, + /*! \brief Timeout during run. */ + kRunTimeoutError = 7, + /*! \brief Unknown error. */ + kUnknonwError = 8, +}; + +// Inputs and results of one measurement + +/*! \brief Store the input of a measurement */ +class MeasureInputNode : public Object { + public: + /*! \brief The search task. */ + SearchTask task; + /*! \brief The program state to be measured. */ + State state; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("task", &task); + v->Visit("state", &state); + } + + /*! \brief Do shallow copy. */ + MeasureInput copy() const; + + static constexpr const char* _type_key = "auto_schedule.MeasureInput"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object); +}; + +/*! + * \brief Managed reference to MeasureInputNode. + * \sa MeasureInputNode + */ +class MeasureInput : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param task The SearchTeask of this measure. + * \param state The State to be measured. + */ + MeasureInput(SearchTask task, State state); + + TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode); +}; + +/*! \brief Store the result of a build. */ +class BuildResultNode : public Object { + public: + /*! \brief The filename of built binary file. */ + String filename; + /*! \brief The arguments. */ + Array args; + /*! \brief The error code. (0 means no error, see MeasureErrorNO) */ + int error_no; + /*! \brief The error message if there is any error. */ + String error_msg; + /*! \brief The time cost of build. */ + double time_cost; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("filename", &filename); + v->Visit("args", &args); + v->Visit("error_no", &error_no); + v->Visit("error_msg", &error_msg); + v->Visit("time_cost", &time_cost); + } + + static constexpr const char* _type_key = "auto_schedule.BuildResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object); +}; + +/*! + * \brief Managed reference to BuildResultNode. + * \sa BuildResultNode + */ +class BuildResult : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param filename The filename of built binary file. + * \param args The arguments. + * \param error_no The error code. + * \param error_msg The error message if there is any error. + * \param time_cost The time cost of build. + */ + BuildResult(String filename, Array args, int error_no, String error_msg, + double time_cost); + TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode); +}; + +/*! \brief Store the results of a measurement. */ +class MeasureResultNode : public Object { + public: + /*! \brief The time costs of execution. */ + Array costs; + /*! \brief The error code. (0 means no error, see MeasureErrorNO) */ + int error_no; + /*! \brief The error message if there is any error. */ + String error_msg; + /*! \brief The time cost of build and run. */ + double all_cost; + /*! \brief The time stamps of this measurement. */ + double timestamp; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("costs", &costs); + v->Visit("error_no", &error_no); + v->Visit("error_msg", &error_msg); + v->Visit("all_cost", &all_cost); + v->Visit("timestamp", ×tamp); + } + + /*! \brief Do shallow copy. */ + MeasureResult copy() const; + + static constexpr const char* _type_key = "auto_schedule.MeasureResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object); +}; + +/*! + * \brief Managed reference to MeasureResultNode. + * \sa MeasureResultNode + */ +class MeasureResult : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param costs The time costs of execution. + * \param error_no The error code. + * \param error_msg The error message if there is any error. + * \param all_cost The time cost of build and run. + * \param timestamp The time stamps of this measurement. + */ + MeasureResult(Array costs, int error_no, String error_msg, double all_cost, + double timestamp); + + TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode); +}; + +/*! \brief Bass class of measurement callbacks */ +class MeasureCallbackNode : public Object { + public: + /*! + * \brief Callback function that will be called on measurement input/result pairs + * after measurement. + * \param policy The current search policy. + * \param inputs An Array of MeasureInput. + * \param results An Array of MeasureResult. + */ + virtual void Callback(const SearchPolicy& policy, const Array& inputs, + const Array& results) = 0; + static constexpr const char* _type_key = "auto_schedule.MeasureCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); +}; + +/*! + * \brief Managed reference to MeasureCallbackNode. + * \sa MeasureCallbackNode + */ +class MeasureCallback : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); +}; + +// The base class of ProgramBuilders and ProgramRunners. + +/*! \brief ProgramBuilder that builds the programs */ +class ProgramBuilderNode : public Object { + public: + /*! \brief The number of tasks to run in parallel */ + int n_parallel; + /*! \brief Timeout of a build */ + int timeout; + + /*! + * \brief Build programs and return results. + * \param inputs An Array of MeasureInput. + * \param verbose Verbosity level. 0 for silent, 1 to output information during program + * building. + * \return An Array of MeasureResult. + */ + virtual Array Build(const Array& inputs, int verbose) = 0; + + static constexpr const char* _type_key = "auto_schedule.ProgramBuilder"; + TVM_DECLARE_BASE_OBJECT_INFO(ProgramBuilderNode, Object); +}; + +/*! + * \brief Managed reference to ProgramBuilderNode. + * \sa ProgramBuilderNode + */ +class ProgramBuilder : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramBuilder, ObjectRef, ProgramBuilderNode); +}; + +/*! \brief ProgramRunner that runs the built programs and measure the time cost. */ +class ProgramRunnerNode : public Object { + public: + /*! \brief Timeout of a run. */ + int timeout; + + /*! + * \brief Run measurement and return results. + * \param inputs An Array of MeasureInput. + * \param build_results An Array of BuildResult. + * \param verbose Verbosity level. 0 for silent, 1 to output information during program + * running. + * \return An Array of MeasureResult. + */ + virtual Array Run(const Array& inputs, + const Array& build_results, int verbose) = 0; + + static constexpr const char* _type_key = "auto_schedule.ProgramRunner"; + TVM_DECLARE_BASE_OBJECT_INFO(ProgramRunnerNode, Object); +}; + +/*! + * \brief Managed reference to ProgramRunnerNode. + * \sa ProgramRunnerNode + */ +class ProgramRunner : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramRunner, ObjectRef, ProgramRunnerNode); +}; + +// Implementation of various builders and runners + +/*! \brief LocalBuilder use local CPU cores to build programs in parallel */ +class LocalBuilderNode : public ProgramBuilderNode { + public: + /*! \brief Build function. */ + String build_func; + + Array Build(const Array& inputs, int verbose) final; + + static constexpr const char* _type_key = "auto_schedule.LocalBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, ProgramBuilderNode); +}; + +/*! + * \brief Managed reference to LocalBuilderNode. + * \sa LocalBuilderNode + */ +class LocalBuilder : public ProgramBuilder { + public: + /*! + * \brief The constructor. + * \param timeout The timeout limit (in second) for each build thread. + * This will be used in a wrapper of the multiprocessing.Process.join(). + * \param n_parallel Number of threads used to build in parallel. + * \param build_func The name of registered build function. + */ + LocalBuilder(int timeout, int n_parallel, const String& build_func); + + TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode); +}; + +/*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ +class LocalRunnerNode : public ProgramRunnerNode { + public: + /*! \brief Number of measure times. */ + int number; + /*! \brief Number of repeat times in each measure. */ + int repeat; + /*! \brief The minimum duration of one repeat in milliseconds. */ + int min_repeat_ms; + /*! \brief The cool down interval between two measurements. */ + double cooldown_interval; + + Array Run(const Array& inputs, + const Array& build_results, int verbose) final; + + static constexpr const char* _type_key = "auto_schedule.LocalRunner"; + TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, ProgramRunnerNode); +}; + +/*! + * \brief Managed reference to LocalRunnerNode. + * \sa LocalRunnerNode + */ +class LocalRunner : public ProgramRunner { + public: + /*! + * \brief The constructor. See the corresponding class in python/tvm/auto_schedule/measure.py + * for more detailed parameter explaination. + * \param timeout The timeout limit (in second) for each run. + * This is used in a wrapper of the multiprocessing.Process.join(). + * \param number Number of measure times. + * \param repeat Number of repeat times in each measure. + * \param min_repeat_ms The minimum duration of one repeat in milliseconds. + * \param cooldown_interval The cool down interval between two measurements. + */ + LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, ProgramRunner, LocalRunnerNode); +}; + +/*! + * \brief Measurer that measures the time costs of tvm programs + * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */ +class ProgramMeasurerNode : public Object { + public: + /*! \brief Measured programs counter. */ + int ct; + /*! \brief Continuous error counter. */ + int error_ct; + /*! \brief Workload key to best flops map. */ + std::unordered_map best_flops; + /*! \brief Workload key to best state map. */ + std::unordered_map best_state; + /*! \brief Workload key to best state's count index map. */ + std::unordered_map best_ct; + /*! \brief The ProgramBuilder to build each program. */ + ProgramBuilder builder; + /*! \brief The ProgramRunner to measure each program. */ + ProgramRunner runner; + /*! \brief MeasureCallback to be called after each measure batch. */ + Optional> callbacks; + /*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */ + int verbose; + /*! \brief The number of max continuous error. */ + int max_continous_error; + + /*! \brief Reset book keeping variables */ + void Reset(); + + /*! + * \brief Do measurement. + * \param task The current SearchTask. + * \param policy The current SearchPolicy. + * \param inputs The MeasureInputs. + * \param results A pointer to a MeasureResult Array, this is used as output. + * \param batch_size Number of programs to be measured in one batch. + */ + void Measure(const SearchTask& task, const SearchPolicy& policy, + const Array& inputs, Array* results, + int batch_size = -1); + /*! + * \brief Do measurement silently. + * This API will not print the measure results to screen. + * \param task The current SearchTask. + * \param inputs The MeasureInputs. + * \param results A pointer to a MeasureResult Array, this is used as output. + */ + void SilentMeasure(const SearchTask& task, const Array& inputs, + Array* results); + + /*! \brief The default max continuous error setting. */ + static const int DEFAULT_MAX_CONTINOUS_ERROR = 150; + + static constexpr const char* _type_key = "auto_schedule.ProgramMeasurer"; + TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object); +}; + +/*! + * \brief Managed reference to ProgramMeasurerNode. + * \sa ProgramMeasurerNode + */ +class ProgramMeasurer : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param builder The ProgramBuilder to build each program. + * \param runner The ProgramRunner to measure each program. + * \param callbacks MeasureCallback to be called after each measure batch. + * \param verbose Verbosity level. 0 for silent, 1 to output information during program + * measuring. + * \param max_continous_error The number of max continuous error. + */ + ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, + Optional> callbacks, int verbose, + int max_continous_error = -1); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode); +}; + +} // namespace auto_schedule +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULE_MEASURE_H_ diff --git a/src/auto_schedule/measure_record.cc b/src/auto_schedule/measure_record.cc new file mode 100644 index 000000000000..99bd5917f7c8 --- /dev/null +++ b/src/auto_schedule/measure_record.cc @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/measure_record.cc + * \brief Json serialization format for dumping and loading tuning records. + */ + +#include "measure_record.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "loop_state.h" +#include "transform_step.h" +#include "utils.h" + +// Json serialization handler for MeasureInput, MeasureResult +// (and recursively for SearchTask, State, Step, ...) +namespace dmlc { +namespace json { + +inline std::vector IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) { + std::vector out; + for (const auto& x : data) { + CHECK(x.defined()); + out.push_back(x); + } + return out; +} + +inline std::vector IntArrayToVector( + const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) { + std::vector out; + for (const auto& x : data) { + CHECK(x); + out.push_back(x.value()); + } + return out; +} + +template <> +struct Handler<::tvm::Array<::tvm::auto_schedule::Stage>> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::Array<::tvm::auto_schedule::Stage>& data) { + writer->BeginArray(false); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::Array<::tvm::auto_schedule::Stage>* data) { + bool s; + reader->BeginArray(); + s = reader->NextArrayItem(); + CHECK(!s); + } +}; + +template <> +struct Handler<::tvm::Array<::tvm::auto_schedule::Step>> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::Array<::tvm::auto_schedule::Step>& data) { + writer->BeginArray(false); + for (size_t i = 0; i < data.size(); ++i) { + writer->WriteArraySeperator(); + writer->BeginArray(false); + if (auto ps = data[i].as<::tvm::auto_schedule::ReorderStepNode>()) { + writer->WriteArrayItem(std::string("RE")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(IntArrayToVector(ps->after_ids)); + } else if (auto ps = data[i].as<::tvm::auto_schedule::SplitStepNode>()) { + writer->WriteArrayItem(std::string("SP")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->extent ? ::tvm::auto_schedule::GetIntImm(ps->extent.value()) + : 0); + writer->WriteArrayItem(IntArrayToVector(ps->lengths)); + writer->WriteArrayItem(static_cast(ps->inner_to_outer)); + } else if (auto ps = data[i].as<::tvm::auto_schedule::FuseStepNode>()) { + writer->WriteArrayItem(std::string("FU")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(IntArrayToVector(ps->fused_ids)); + } else { + LOG(FATAL) << "Invalid step: " << data[i]; + } + writer->EndArray(); + } + writer->EndArray(); + } + + inline static void Read(dmlc::JSONReader* reader, + ::tvm::Array<::tvm::auto_schedule::Step>* data) { + std::vector int_list; + bool s, inner_to_outer; + std::string name, scope_name, pragma_type, ti_func_name; + int stage_id, iter_id, extent; + + reader->BeginArray(); + data->clear(); + while (reader->NextArrayItem()) { + reader->BeginArray(); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&name); + if (name == "RE") { + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&int_list); + ::tvm::Array<::tvm::Integer> after_ids; + for (const auto& i : int_list) { + after_ids.push_back(i); + } + data->push_back(::tvm::auto_schedule::ReorderStep(stage_id, after_ids)); + } else if (name == "SP") { + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&extent); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&int_list); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&inner_to_outer); + ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths; + for (const auto& i : int_list) { + lengths.push_back(::tvm::Integer(i)); + } + data->push_back(::tvm::auto_schedule::SplitStep( + stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, lengths, inner_to_outer)); + } else if (name == "FU") { + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&int_list); + ::tvm::Array<::tvm::Integer> fused_ids; + for (const auto& i : int_list) { + fused_ids.push_back(i); + } + data->push_back(::tvm::auto_schedule::FuseStep(stage_id, fused_ids)); + } else { + LOG(FATAL) << "Invalid step format"; + } + s = reader->NextArrayItem(); + CHECK(!s); + } + } +}; + +template <> +struct Handler<::tvm::auto_schedule::StateNode> { + inline static void Write(dmlc::JSONWriter* writer, const ::tvm::auto_schedule::StateNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(data.stages); + writer->WriteArrayItem(data.transform_steps); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::StateNode* data) { + reader->BeginArray(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&data->stages); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&data->transform_steps); + s = reader->NextArrayItem(); + CHECK(!s); + } +}; + +template <> +struct Handler<::tvm::auto_schedule::SearchTaskNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::auto_schedule::SearchTaskNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(std::string(data.workload_key)); + writer->WriteArrayItem(data.target->str()); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::SearchTaskNode* data) { + std::string target_str; + bool s; + + reader->BeginArray(); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&target_str); + data->workload_key = std::move(target_str); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&target_str); + data->target = ::tvm::Target::Create(target_str); + s = reader->NextArrayItem(); + CHECK(!s); + } +}; + +template <> +struct Handler<::tvm::auto_schedule::MeasureInputNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::auto_schedule::MeasureInputNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(*data.task.operator->()); + writer->WriteArrayItem(*data.state.operator->()); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::MeasureInputNode* data) { + bool s; + auto task_node = ::tvm::make_object<::tvm::auto_schedule::SearchTaskNode>(); + auto state_node = ::tvm::make_object<::tvm::auto_schedule::StateNode>(); + state_node->concrete = true; + + reader->BeginArray(); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(task_node.get()); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(state_node.get()); + s = reader->NextArrayItem(); + CHECK(!s); + + data->task = ::tvm::auto_schedule::SearchTask(task_node); + data->state = ::tvm::auto_schedule::State(state_node); + } +}; + +template <> +struct Handler<::tvm::auto_schedule::MeasureResultNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::auto_schedule::MeasureResultNode& data) { + writer->BeginArray(false); + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (const auto& x : data.costs) { + auto pf = x.as<::tvm::tir::FloatImmNode>(); + CHECK(pf != nullptr) << "Cost can only contain float values"; + writer->WriteArrayItem(pf->value); + } + writer->EndArray(); + writer->WriteArrayItem(data.error_no); + writer->WriteArrayItem(data.all_cost); + writer->WriteArrayItem(static_cast((data.timestamp))); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::MeasureResultNode* data) { + bool s; + std::vector tmp; + + reader->BeginArray(); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&tmp); + data->costs.clear(); + for (const auto& i : tmp) { + data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i)); + } + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&data->error_no); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&data->all_cost); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&data->timestamp); + s = reader->NextArrayItem(); + CHECK(!s); + } +}; + +} // namespace json +} // namespace dmlc + +namespace tvm { +namespace auto_schedule { + +TVM_REGISTER_OBJECT_TYPE(RecordToFileNode); +TVM_REGISTER_OBJECT_TYPE(RecordReaderNode); + +const std::string AUTO_SCHEDULE_LOG_VERSION = "v0.2"; // NOLINT(*) + +RecordToFile::RecordToFile(String filename) { + auto node = make_object(); + node->filename = std::move(filename); + data_ = std::move(node); +} + +void WriteMeasureRecords(std::ostream* os, const Array& inputs, + const Array& results) { + dmlc::JSONWriter writer(os); + for (size_t i = 0; i < inputs.size(); ++i) { + writer.BeginObject(false); + writer.WriteObjectKeyValue("i", *inputs[i].operator->()); + writer.WriteObjectKeyValue("r", *results[i].operator->()); + writer.WriteObjectKeyValue("v", AUTO_SCHEDULE_LOG_VERSION); + writer.EndObject(); + *os << "\n"; + } +} + +void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, + std::string* log_version) { + std::istringstream ss(str); + dmlc::JSONReader reader(&ss); + std::string key; + + reader.BeginObject(); + while (reader.NextObjectItem(&key)) { + if (key == "i") { + reader.Read(inp); + } else if (key == "r") { + reader.Read(res); + } else if (key == "v") { + reader.Read(log_version); + } else { + LOG(FATAL) << "Invalid key in json log: " << key; + } + } +} + +void RecordToFileNode::Callback(const SearchPolicy& policy, const Array& inputs, + const Array& results) { + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, inputs, results); +} + +RecordReader::RecordReader(String filename) { + auto node = make_object(); + node->filename = filename; + node->infile.open(filename, std::ifstream::in); + data_ = std::move(node); +} + +RecordReaderNode::~RecordReaderNode() { infile.close(); } + +bool RecordReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { + std::string log_version; + + while (std::getline(infile, cur_line_)) { + if (cur_line_[0] == '#' || cur_line_[0] == ' ') { + // skip comment lines begin with '#' or ' ' + continue; + } + ReadMeasureRecord(cur_line_, inp, res, &log_version); + return true; + } + + return false; +} + +std::pair, Array> RecordReaderNode::ReadLines(int max_size, + int skip_size) { + auto inp = make_object(); + auto res = make_object(); + Array inputs; + Array results; + + while (ReadNext(inp.get(), res.get())) { + if (skip_size > 0) { + skip_size--; + continue; + } + + inputs.push_back(inp->copy()); + results.push_back(res->copy()); + + if (max_size > 0 && static_cast(inputs.size()) >= max_size) { + break; + } + } + + return std::make_pair(inputs, results); +} + +TVM_REGISTER_GLOBAL("auto_schedule.RecordToFile").set_body_typed([](const String& filename) { + return RecordToFile(filename); +}); + +TVM_REGISTER_GLOBAL("auto_schedule.RecordReader").set_body_typed([](const String& filename) { + return RecordReader(filename); +}); + +TVM_REGISTER_GLOBAL("auto_schedule.RecordReaderReadLines") + .set_body_typed([](RecordReader reader, int size, int skip_size) { + const auto& res = reader->ReadLines(size, skip_size); + return Array{res.first, res.second}; + }); + +TVM_REGISTER_GLOBAL("auto_schedule.RecordReaderReadNext").set_body_typed([](RecordReader reader) { + auto inp = make_object(); + auto res = make_object(); + if (reader->ReadNext(inp.get(), res.get())) { + return Array{ObjectRef(inp), ObjectRef(res)}; + } else { + return Array(); + } +}); + +TVM_REGISTER_GLOBAL("auto_schedule.SaveRecords") + .set_body_typed([](String filename, Array in, Array res) { + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); + }); +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/measure_record.h b/src/auto_schedule/measure_record.h new file mode 100644 index 000000000000..f97e30ae9f70 --- /dev/null +++ b/src/auto_schedule/measure_record.h @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/measure_record.h + * \brief Json serialization format for dumping and loading tuning records. + */ + +#ifndef TVM_AUTO_SCHEDULE_MEASURE_RECORD_H_ +#define TVM_AUTO_SCHEDULE_MEASURE_RECORD_H_ + +#include +#include +#include + +#include "measure.h" + +namespace tvm { +namespace auto_schedule { + +/*! \brief Callback for logging the input and results of measurements to file */ +class RecordToFileNode : public MeasureCallbackNode { + public: + /*! \brief File name for this callback to write log to. */ + String filename; + + void Callback(const SearchPolicy& policy, const Array& inputs, + const Array& results) final; + + static constexpr const char* _type_key = "auto_schedule.RecordToFile"; + TVM_DECLARE_FINAL_OBJECT_INFO(RecordToFileNode, MeasureCallbackNode); +}; + +/*! + * \brief Managed reference to RecordToFileNode. + * \sa RecordToFileNode + */ +class RecordToFile : public MeasureCallback { + public: + /*! + * \brief The constructor. + * \param filename File name for this callback to write log. + */ + explicit RecordToFile(String filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordToFile, MeasureCallback, RecordToFileNode); +}; + +/*! \brief Log reader to load step logs from a file.*/ +class RecordReaderNode : public Object { + public: + /*! \brief File name for this reader to load log from. */ + String filename; + /*! \brief The reading file stream. */ + std::ifstream infile; + + ~RecordReaderNode(); + + /*! + * \brief Read next line in the log file. + * \param inp A pointer to a MeasureInputNode, this is used as output. + * \param res A pointer to a MeasureResultNode, this is used as output. + * \return Whether the read is successful. */ + bool ReadNext(MeasureInputNode* inp, MeasureResultNode* res); + + /*! + * \brief Read multiple lines from the log file. + * \param max_size The maximum number of lines. -1 means read all lines. + * \param skip_size Skip the first n lines. + * \return The MeasureInputs and MeasureResults loaded from the log file. + */ + std::pair, Array> ReadLines(int max_size = -1, + int skip_size = 0); + + static constexpr const char* _type_key = "auto_schedule.RecordReader"; + TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object); + + private: + /*! \brief A string object to store the next line. */ + std::string cur_line_; +}; + +/*! + * \brief Managed reference to RecordReaderNode. + * \sa RecordReaderNode + */ +class RecordReader : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param filename File name for this callback to write log. + */ + explicit RecordReader(String filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordReader, ObjectRef, RecordReaderNode); +}; + +/*! + * \brief Write measure records to an output stream. + * \param os A pointer to a output stream. + * \param inputs The MeasureInputs to be written. + * \param results The MeasureResults to be written. + */ +void WriteMeasureRecords(std::ostream* os, const Array& inputs, + const Array& results); + +/*! + * \brief Read one measure record from a string. + * \param str The record string to be extract. + * \param inp A pointer to a MeasureInputNode, this is used as output. + * \param res A pointer to a MeasureResultNode, this is used as output. + * \param log_version A pointer to a log version string. + */ +void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, + std::string* log_version); + +} // namespace auto_schedule +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULE_MEASURE_RECORD_H_ diff --git a/src/auto_schedule/search_policy/empty_policy.cc b/src/auto_schedule/search_policy/empty_policy.cc new file mode 100644 index 000000000000..d91d563252b5 --- /dev/null +++ b/src/auto_schedule/search_policy/empty_policy.cc @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/search_policy/empty_policy.cc + * \brief This is an brief example of search policy. + */ + +#include "empty_policy.h" + +#include + +#include "../measure.h" + +namespace tvm { +namespace auto_schedule { + +TVM_REGISTER_NODE_TYPE(EmptyPolicyNode); + +State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early_stopping, + int num_measures_per_round, int verbose, ProgramMeasurer measurer, + Optional> pre_search_callbacks) { + cur_task = task; + + // Run pre_search_callbacks before the search process + // This Interface is usually used to set some init status + RunCallbacks(pre_search_callbacks); + + // Basic design principe: `SearchOneRound()` several times to get candidate states, + // measure them and return the best one + // Measure is disabled if num_measure_trials <= 1 + if (num_measure_trials <= 1) { + const auto& res = SearchOneRound(); + CHECK_GT(res.size(), 0); + + return res[0]; + } else { + Array inputs; + Array results; + + measurer->Reset(); + int ct = 0; + // In each round, we call SearchOneRound to get several candidate states, + // then use ProgramMeasurer to test their performance + while (ct < num_measure_trials) { + const auto& res = SearchOneRound(); + ct += res.size(); + // Build MeasureInputs for measuring + inputs.clear(); + for (const auto& state : res) { + // The class members measured_states_set_ provided by SearchPolicy can be used to filter + // out the already measured states + inputs.push_back(MeasureInput(cur_task, state)); + } + // ProgramMeasurer will record the state with best performance during measure process + measurer->Measure(cur_task, GetRef(this), inputs, &results); + } + + // Return a state with best measured performance + return measurer->best_state[cur_task->workload_key]; + } +} + +// As an example policy, EmptyPolicy always returns a init state +Array EmptyPolicyNode::SearchOneRound() { + Array res; + + // 1. We will process `Program sampling` first to generate several initial schedules + res.push_back(cur_task->compute_dag->init_state); + + // 2. Then `Performance Tuning`: use cost model and evolutionary search to seek for the schedule + // with best performance + // Note: This example policy does not include this part + + // 3. The returned candidate schedules will be measured in hardware + return res; +} + +TVM_REGISTER_GLOBAL("auto_schedule.EmptyPolicy").set_body_typed([]() { + return EmptyPolicy(make_object()); +}); + +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/search_policy/empty_policy.h b/src/auto_schedule/search_policy/empty_policy.h new file mode 100644 index 000000000000..610a02a3cd12 --- /dev/null +++ b/src/auto_schedule/search_policy/empty_policy.h @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/search_policy/empty_policy.h + * \brief A brief example of the search policy which always returns the initial naive schedule + * (state). + */ + +#ifndef TVM_AUTO_SCHEDULE_SEARCH_POLICY_EMPTY_POLICY_H_ +#define TVM_AUTO_SCHEDULE_SEARCH_POLICY_EMPTY_POLICY_H_ + +#include "../loop_state.h" +#include "search_policy.h" + +namespace tvm { +namespace auto_schedule { + +/*! + * \brief A brief example of the search policy which always returns the initial naive schedule + * (state), the formal search policy will continue to follow its design. + * The key implementation for this structure is `Search()`, check `empty_policy.cc` for more + * details. + */ +class EmptyPolicyNode : public SearchPolicyNode { + public: + State Search(SearchTask task, int num_measure_trials, int early_stopping, + int num_measures_per_round, int verbose, ProgramMeasurer measurer, + Optional> pre_search_callbacks) final; + + static constexpr const char* _type_key = "auto_schedule.EmptyPolicy"; + TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode); + + private: + /*! + * \brief Use a sub function to generate several candidate states in each search round. + * \returns Several generated states + */ + Array SearchOneRound(); +}; + +/*! + * \brief Managed reference to EmptyPolicyNode. + * \sa EmptyPolicyNode + */ +class EmptyPolicy : public SearchPolicy { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EmptyPolicy, SearchPolicy, EmptyPolicyNode); +}; + +} // namespace auto_schedule +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULE_SEARCH_POLICY_EMPTY_POLICY_H_ diff --git a/src/auto_schedule/search_policy/search_policy.cc b/src/auto_schedule/search_policy/search_policy.cc new file mode 100644 index 000000000000..f8ac7ca39495 --- /dev/null +++ b/src/auto_schedule/search_policy/search_policy.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/search_policy/search_policy.cc + * \brief The base class of search policies. + */ + +#include "search_policy.h" + +#include + +namespace tvm { +namespace auto_schedule { + +TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode); +TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); + +void SearchPolicyNode::RunCallbacks(const Optional>& callbacks) { + if (callbacks) { + for (const auto& callback : callbacks.value()) { + callback->Callback(this); + } + } +} + +TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicyRunCallbacks") + .set_body_typed([](SearchPolicy policy, Optional> callbacks) { + policy->RunCallbacks(callbacks); + }); + +TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicySetTask") + .set_body_typed([](SearchPolicy policy, SearchTask task) { policy->cur_task = task; }); + +TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicySetVerbose") + .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; }); + +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/search_policy/search_policy.h b/src/auto_schedule/search_policy/search_policy.h new file mode 100644 index 000000000000..47cccec93661 --- /dev/null +++ b/src/auto_schedule/search_policy/search_policy.h @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/search_policy/search_policy.h + * \brief The base class of search policies, including the abstract definition of search policy and + * other supporting data structures. + * + * The basic schedule search process for TVM Auto-scheduler is design to be: + * `Program sampling` -> `Performance Tuning`. + * + * In `Program sampling`, we use some predefined precise or heuristic rules to generate several + * initial schedules. Based on these initial starting points, we perform `Performance Tuning` which + * uses cost model based evolutionary search to select schedules with the best performance. + * + * Candidate schedules are measured against the specific hardware target. + * + * \note Adding a new search policy. + * In design, there's no need for users to implement their own search policy, our formal search + * policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule + * mechanism will be provided to enable user-defined template search to serve the same functionality + * as the current AutoTVM template. + * + * This guide is for advanced uses who have special requirements. + * 1. The only function that must be implemented is Search(), which takes a task as input and + * returns the best states found. + * 2. Information about the compute declaration of ops/subgraphs can be acquired from SearchTask. + * This structure also contains some information about the target device. (e.g. knowing the width + * of the device vector unit, we can limit the max vectorize size during schedule search) + * 3. SearchCallback provides more flexibility to do extra affairs before/after the search process. + * 4. ProgramMeasurer provides a simple but useful api to help check the performance of states got + * during the search process. + */ + +#ifndef TVM_AUTO_SCHEDULE_SEARCH_POLICY_SEARCH_POLICY_H_ +#define TVM_AUTO_SCHEDULE_SEARCH_POLICY_SEARCH_POLICY_H_ + +#include + +#include +#include + +#include "../search_task.h" + +namespace tvm { +namespace auto_schedule { + +class ProgramMeasurer; +class SearchPolicyNode; + +/*! + * \brief Callback function to be called by the search process. + * This interface allows to do extra initializations before schedule search or extra + * check during/after the schedule search. + */ +class SearchCallbackNode : public Object { + public: + /*! + * \brief Run the registered callback function. + * \param policy A pointer to a SearchPolicyNode. + */ + virtual void Callback(SearchPolicyNode* policy) = 0; + + static constexpr const char* _type_key = "auto_schedule.SearchCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); +}; + +/*! + * \brief Managed reference to SearchCallbackNode. + * \sa SearchCallbackNode + */ +class SearchCallback : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode); +}; + +/*! + * \brief The base class of search policies. + */ +class SearchPolicyNode : public Object { + public: + /*! \brief The current search task. */ + SearchTask cur_task; + /*! + * \brief Verbose level to control the screen output during schedule search. + * 0 for silent, 1 to output state & measure information during search process. + */ + int verbose; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("cur_task", &cur_task); + v->Visit("verbose", &verbose); + } + + /*! + * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state + * get during the search process. + * \param task The SearchTask or workload key for the computation declaration + * \param num_measure_trials Total schedules to be tried during this search. + * \param early_stopping Early stop if no better schedule is found. + * \param num_measures_per_round Max measure batch in one search round. + * \param verbose Verbose level. 0 for silent, 1 to output information during schedule + * search. + * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside. + * \param pre_search_callbacks SearchCallback to be called before schedule search. + * \return The best state get. + */ + virtual State Search(SearchTask task, int num_measure_trials, int early_stopping, + int num_measures_per_round, int verbose, ProgramMeasurer measurer, + Optional> pre_search_callbacks) = 0; + + /*! + * \brief Call SearchCallback with the current SearchPolicyNode + * \param callbacks SearchCallback to be called. + */ + void RunCallbacks(const Optional>& callbacks); + + static constexpr const char* _type_key = "auto_schedule.SearchPolicy"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); + + protected: + /*! + * \brief The set of already measured states. + * During the schedule search process, we may generate `equal states` through different search + * branches. (Equal States: 1. the transform steps are totally the same; 2. even with different + * steps, two states may still result in a same schedule. e.g. To split a axis with extent 512 + * to 3 parts [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can + * get a same result to split from outter to inner by factors [8, 16]) + * We store the string format of a state for redundancy check. This is used to make sure a + * measured state will never be measured again. + */ + std::unordered_set measured_states_set_; + /*! \brief The array of already measured states. This can be used in evolutionary search. */ + std::vector measured_states_vector_; + /*! \brief The throughputs of already measured states */ + std::vector measured_states_throughputs_; +}; + +/*! + * \brief Managed reference to SearchPolicyNode. + * \sa SearchPolicyNode + */ +class SearchPolicy : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchPolicy, ObjectRef, SearchPolicyNode); +}; + +} // namespace auto_schedule +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULE_SEARCH_POLICY_SEARCH_POLICY_H_ diff --git a/src/auto_schedule/search_task.cc b/src/auto_schedule/search_task.cc new file mode 100644 index 000000000000..1d7a08cc73db --- /dev/null +++ b/src/auto_schedule/search_task.cc @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/search_task.cc + * \brief Meta information and hardware parameters for a search task. + */ + +#include "search_task.h" + +#include +#include + +#include + +namespace tvm { +namespace auto_schedule { + +TVM_REGISTER_NODE_TYPE(HardwareParamsNode); +TVM_REGISTER_NODE_TYPE(SearchTaskNode); + +HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes) { + auto node = make_object(); + node->num_cores = num_cores; + node->vector_unit_bytes = vector_unit_bytes; + node->cache_line_bytes = cache_line_bytes; + data_ = std::move(node); +} + +HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target, + const Target& target_host) { + if (target->id->name == "llvm") { + return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64); + } else { + LOG(FATAL) << "No default hardware parameters for target: " << target; + } + return HardwareParams(); +} + +SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, + Target target_host, Optional hardware_params) { + auto node = make_object(); + node->compute_dag = std::move(compute_dag); + node->workload_key = std::move(workload_key); + node->target = std::move(target); + node->target_host = std::move(target_host); + if (hardware_params) { + node->hardware_params = hardware_params.value(); + } else { + node->hardware_params = + HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); + } + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("auto_schedule.HardwareParams") + .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes) { + return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes); + }); + +TVM_REGISTER_GLOBAL("auto_schedule.SearchTask") + .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, + Target target_host, Optional hardware_params) { + return SearchTask(compute_dag, workload_key, target, target_host, hardware_params); + }); + +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/search_task.h b/src/auto_schedule/search_task.h new file mode 100644 index 000000000000..c7d2ddc533ed --- /dev/null +++ b/src/auto_schedule/search_task.h @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/search_task.h + * \brief Meta information and hardware parameters for a search task. + */ + +#ifndef TVM_AUTO_SCHEDULE_SEARCH_TASK_H_ +#define TVM_AUTO_SCHEDULE_SEARCH_TASK_H_ + +#include + +#include "compute_dag.h" + +namespace tvm { +namespace auto_schedule { + +class HardwareParams; + +/*! \brief The parameters of target hardware used to guide the search process of SearchPolicy. */ +class HardwareParamsNode : public Object { + public: + /*! \brief The number of cores. */ + int num_cores; + /*! \brief The width of vector units in bytes. */ + int vector_unit_bytes; + /*! \brief The size of cache line in bytes. */ + int cache_line_bytes; + + // GPU related parameters got from device query API + + /*! \brief The max shared memory per block. */ + int max_shared_memory_per_block{INT32_MAX}; + /*! \brief The max register memory per block. */ + int max_registers_per_block{INT32_MAX}; + /*! \brief The max threads per block. */ + int max_threads_per_block{INT32_MAX}; + /*! \brief The max vthread extent. */ + int max_vthread_extent{INT32_MAX}; + /*! \brief The thread numbers of a warp. */ + int warp_size{INT32_MAX}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_cores", &num_cores); + v->Visit("vector_unit_bytes", &vector_unit_bytes); + v->Visit("cache_line_bytes", &cache_line_bytes); + v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block); + v->Visit("max_registers_per_block", &max_registers_per_block); + v->Visit("max_threads_per_block", &max_threads_per_block); + v->Visit("max_vthread_extent", &max_vthread_extent); + v->Visit("warp_size", &warp_size); + } + + /*! + * \brief Get the default hardware params. + * \param target A `tvm.target`. + * \param target_host A `tvm.target` for host device. + * \return A HardwareParams object. + */ + static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); + + static constexpr const char* _type_key = "auto_schedule.HardwareParams"; + TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); +}; + +/*! + * \brief Managed reference to HardwareParamsNode. + * \sa HardwareParamsNode + */ +class HardwareParams : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param num_cores The number of cores. + * \param vector_unit_bytes The width of vector units in bytes. + * \param cache_line_bytes The size of cache line in bytes. + */ + HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes); + + TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode); +}; + +/*! + * \brief The computation information and hardware parameters for a specific schedule search task. + */ +class SearchTaskNode : public Object { + public: + /*! \brief The ComputeDAG for the compute declaration. */ + ComputeDAG compute_dag; + /*! \brief The workload key for the compute declaration. */ + String workload_key; + /*! \brief The target device of this search task. */ + Target target; + /*! \brief The target host device of this search task. */ + Target target_host; + /*! \brief Hardware parameters used in this search task. */ + HardwareParams hardware_params; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("compute_dag", &compute_dag); + v->Visit("workload_key", &workload_key); + v->Visit("target", &target); + v->Visit("target_host", &target_host); + v->Visit("hardware_params", &hardware_params); + } + + static constexpr const char* _type_key = "auto_schedule.SearchTask"; + TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); +}; + +/*! + * \brief Managed reference to SearchTaskNode. + * \sa SearchTaskNode + */ +class SearchTask : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param compute_dag The ComputeDAG for the compute declaration. + * \param workload_key The workload key for the compute declaration. + * \param target The target device of this search task. + * \param target_host The target host device of this search task. + * \param hardware_params Hardware parameters used in this search task. + */ + SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, + Optional hardware_params); + + TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); +}; + +} // namespace auto_schedule +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULE_SEARCH_TASK_H_ diff --git a/src/auto_schedule/transform_step.cc b/src/auto_schedule/transform_step.cc new file mode 100644 index 000000000000..bffb2dcfab31 --- /dev/null +++ b/src/auto_schedule/transform_step.cc @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/transform_step.cc + * \brief Transformation steps. For each schedule primitive, there is a corresponding transform + * step. + */ + +#include "transform_step.h" + +#include +#include + +#include + +#include "loop_state.h" +#include "utils.h" + +namespace tvm { +namespace auto_schedule { + +/********** Reorder **********/ +ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { + auto node = make_object(); + node->stage_id = stage_id; + for (const auto& x : after_ids) { + CHECK(x->IsInstance()); + } + node->after_ids = after_ids; + data_ = std::move(node); +} + +void ReorderStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + auto stage = (*stages)[stage_id]; + const Array& axes = stage_to_axes->at(stage); + CHECK_EQ(after_ids.size(), axes.size()); + + Array new_axes; + new_axes.reserve(axes.size()); + for (auto i : after_ids) { + new_axes.push_back(axes[i]); + } + stage.reorder(new_axes); + + stage_to_axes->Set(stage, std::move(new_axes)); + stages->Set(stage_id, std::move(stage)); +} + +String ReorderStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { + const auto& stage = (*stages)[stage_id]; + std::stringstream ss; + + ss << "s[" << CleanName(stage->op->name) << "].reorder("; + for (size_t i = 0; i < after_ids.size(); ++i) { + ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); + if (i != after_ids.size() - 1) { + ss << ", "; + } + } + ss << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Split **********/ +Array ApplySplitToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + int stage_id, int iter_id, + const Array>& lengths, bool inner_to_outer) { + auto stage = (*stages)[stage_id]; + const Array& axes = stage_to_axes->at(stage); + + Array outs; + if (inner_to_outer) { + IterVar outer = axes[iter_id], inner; + for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { + IterVar to_split = outer; + stage.split(to_split, lengths[i].value(), &outer, &inner); + outs.push_back(inner); + } + outs.push_back(outer); + } else { + IterVar outer, inner = axes[iter_id]; + for (size_t i = 0; i < lengths.size(); i++) { + IterVar to_split = inner; + stage.split_by_nparts(to_split, lengths[i].value(), &outer, &inner); + outs.push_back(outer); + } + outs.push_back(inner); + } + + Array new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); + if (inner_to_outer) { + for (auto x = outs.rbegin(); x != outs.rend(); ++x) { + new_axes.push_back((*x)); + } + } else { + for (const auto& x : outs) { + new_axes.push_back(x); + } + } + new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); + + stage_to_axes->Set(stage, std::move(new_axes)); + stages->Set(stage_id, std::move(stage)); + return outs; +} + +String PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, int stage_id, + int iter_id, const Array>& lengths, + bool inner_to_outer) { + const auto& stage = (*stages)[stage_id]; + auto to_split = stage_to_axes->at(stage)[iter_id]; + const auto& func_name = CleanName(stage->op->name); + const auto& outs = + ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); + CHECK_EQ(outs.size(), lengths.size() + 1); + + std::stringstream ss; + int size = static_cast(lengths.size()); + if (inner_to_outer) { + for (int i = size - 1; i >= 0; i--) { + ss << CleanName(outs[size - i]->var->name_hint) << ", " + << CleanName(outs[size - i - 1]->var->name_hint) << " = s[" << func_name << "].split(" + << CleanName(to_split->var->name_hint) << ", factor=" << lengths[i] << ")\n"; + to_split = outs[size - i]; + } + } else { + for (int i = 0; i < size; i++) { + ss << CleanName(outs[i]->var->name_hint) << ", " << CleanName(outs[i + 1]->var->name_hint) + << " = s[" << func_name << "].split(" << CleanName(to_split->var->name_hint) + << ", nparts=" << lengths[i] << ")\n"; + to_split = outs[i + 1]; + } + } + + return ss.str(); +} + +SplitStep::SplitStep(int stage_id, int iter_id, Optional extent, + const Array>& lengths, bool inner_to_outer) { + auto node = make_object(); + node->stage_id = stage_id; + // Extent can be a unreducible expression in some special cases + if (extent && extent.value()->IsInstance()) { + node->extent = tvm::Downcast(extent.value()); + } + node->iter_id = iter_id; + node->lengths = lengths; + node->inner_to_outer = inner_to_outer; + data_ = std::move(node); +} + +Array SplitStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); +} + +String SplitStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); +} + +/********** Fuse **********/ +FuseStep::FuseStep(int stage_id, const Array& fused_ids) { + auto node = make_object(); + node->stage_id = stage_id; + for (const auto& x : fused_ids) { + CHECK(x->IsInstance()); + } + node->fused_ids = fused_ids; + data_ = std::move(node); +} + +IterVar FuseStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + auto stage = (*stages)[stage_id]; + const Array& axes = stage_to_axes->at(stage); + + Array to_fuse; + for (const auto& i : fused_ids) { + to_fuse.push_back(axes[i]); + } + IterVar fused_axis; + stage.fuse(to_fuse, &fused_axis); + + Array new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front()); + new_axes.push_back(fused_axis); + new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end()); + + stage_to_axes->Set(stage, std::move(new_axes)); + stages->Set(stage_id, std::move(stage)); + return fused_axis; +} + +String FuseStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { + const auto& stage = (*stages)[stage_id]; + std::stringstream to_fuse; + + for (size_t i = 0; i < fused_ids.size(); ++i) { + to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint); + if (i != fused_ids.size() - 1) { + to_fuse << ", "; + } + } + + std::stringstream ss; + const auto& fused = ApplyToSchedule(stages, stage_to_axes); + + ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse(" + << to_fuse.str() << ")\n"; + + return ss.str(); +} + +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/transform_step.h b/src/auto_schedule/transform_step.h new file mode 100644 index 000000000000..5e54c9de583f --- /dev/null +++ b/src/auto_schedule/transform_step.h @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/transform_step.h + * \brief Transformation steps. For each schedule primitive, there is a corresponding transform + * step. The implementation of each step consists of 2 parts: + * - transform_step.cc: How each step interacts with TE and TE's schedule primitives + * - loop_state.cc: How each step updates LoopState + * + * \note To add a new transform step: + * Take fuse step for example: + * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction + * function `FuseStep::FuseStep(...)` in `transform_steps.cc` + * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`. + * - In these two functions you need to lower this step with tvm's te schedule API + * 3. Implement `State::fuse` and `State::DoFuseStep`. + * - In these two functions you need to incrementally update all data structures in State with + * CopyOnWrite style + * 4. Add you step to `ComputeDAG::ApplySteps` and make sure it works. + * 5. Add log record serialization support in `struct Handler>` + * in `record.cc`. + * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test. + */ + +#ifndef TVM_AUTO_SCHEDULE_TRANSFORM_STEP_H_ +#define TVM_AUTO_SCHEDULE_TRANSFORM_STEP_H_ + +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace auto_schedule { + +typedef Map, ObjectHash, ObjectEqual> StageToAxesMap; + +/*! + * \brief The base class of transformation steps. Each step has its corresponding tvm.te + * schedule primitives. + */ +class StepNode : public Object { + public: + /*! \brief The index of the stage. */ + int stage_id; + + static constexpr const char* _type_key = "auto_schedule.Step"; + TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); +}; + +/*! + * \brief Managed reference to StepNode. + * \sa StepNode + */ +class Step : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode); +}; + +/*! \brief Reorder step that corresponds to te::Stage::reorder */ +class ReorderStepNode : public StepNode { + public: + /*! + * \brief The iterator ids after reorder. + * This array should specify the order of all iterators. + */ + Array after_ids; + + /*! + * \brief Apply the current state to tvm.schedule + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; + + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + + static constexpr const char* _type_key = "auto_schedule.ReorderStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); +}; + +/*! + * \brief Managed reference to ReorderStepNode. + * \sa ReorderStepNode + */ +class ReorderStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be reordered. + * \param after_ids The expected indexes of the iterators after reorder. + */ + ReorderStep(int stage_id, const Array& after_ids); + + TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); +}; + +/*! + * \brief Split step that corresponds to te::Stage::split with additional + * support of multiple-level of factors + */ +class SplitStepNode : public StepNode { + public: + /*! \brief The id of the iter to split. */ + int iter_id; + /*! \brief The extent length of the axis to split. */ + Optional extent; + /*! \brief The split factors. */ + Array> lengths; + /*! + * \brief If true, the `lengths` denote the lengths of iterators + * from inner level to outer level + */ + bool inner_to_outer; + + /*! + * \brief Apply the current state to tvm.schedule + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return The iterator results after split. + */ + Array ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const; + + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + + static constexpr const char* _type_key = "auto_schedule.SplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); +}; + +/*! + * \brief Managed reference to SplitStepNode. + * \sa SplitStepNode + */ +class SplitStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be split. + * \param iter_id The index of the iterator to be split. + * \param extent The extent length of the axis to split. + * \param lengths The multiple split factors. Can be None to be filled by search policy. + * \param inner_to_outer The split direction. + */ + SplitStep(int stage_id, int iter_id, Optional extent, + const Array>& lengths, bool inner_to_outer); + + TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); +}; + +/*! \brief Fuse step that corresponds to te::Stage::fuse */ +class FuseStepNode : public StepNode { + public: + /*! \brief The ids of iterators to fuse. */ + Array fused_ids; + + /*! + * \brief Apply the current state to tvm.schedule + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return The iterator result after fuse. + */ + tir::IterVar ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; + + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + + static constexpr const char* _type_key = "auto_schedule.FuseStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); +}; + +/*! + * \brief Managed reference to FuseStepNode. + * \sa FuseStepNode + */ +class FuseStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be fused. + * \param fused_ids The index of the iterators to be fused. + */ + FuseStep(int stage_id, const Array& fused_ids); + + TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); +}; + +} // namespace auto_schedule +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULE_TRANSFORM_STEP_H_ diff --git a/src/auto_schedule/utils.cc b/src/auto_schedule/utils.cc new file mode 100644 index 000000000000..ecb6145268d6 --- /dev/null +++ b/src/auto_schedule/utils.cc @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/utils.cc + * \brief Common utilities. + */ + +#include "utils.h" + +namespace tvm { +namespace auto_schedule { + +NullStream& NullStream::Global() { + static NullStream stream; + return stream; +} + +} // namespace auto_schedule +} // namespace tvm diff --git a/src/auto_schedule/utils.h b/src/auto_schedule/utils.h new file mode 100644 index 000000000000..e91bc106fb51 --- /dev/null +++ b/src/auto_schedule/utils.h @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_schedule/utils.h + * \brief Common utilities. + */ + +#ifndef TVM_AUTO_SCHEDULE_UTILS_H_ +#define TVM_AUTO_SCHEDULE_UTILS_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace std { + +/*! \brief Hash function for std::pair */ +template +struct hash> { + std::size_t operator()(const std::pair& k) const { + return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } +}; + +/*! \brief Hash function for std::tuple */ +template +struct hash> { + std::size_t operator()(const std::tuple& k) const { + return ::dmlc::HashCombine( + ::dmlc::HashCombine(std::hash()(std::get<0>(k)), std::hash()(std::get<1>(k))), + std::hash()(std::get<2>(k))); + } +}; + +} // namespace std + +namespace tvm { +namespace auto_schedule { + +/********** Utilities for Array, std::string **********/ +/*! \brief Get the first appearance index of elements in an Array */ +template +inline void GetIndices(const Array& array, const Array& to_locate, Array* indices) { + for (const auto& v : to_locate) { + auto it = std::find(array.begin(), array.end(), v); + if (it != array.end()) { + indices->push_back(it - array.begin()); + } else { + LOG(FATAL) << "Cannot find the item"; + } + } +} + +/*! \brief Get the first appearance index of an element in an Array */ +template +inline int GetIndex(const Array& array, const T& to_locate) { + for (size_t i = 0; i < array.size(); ++i) { + if (array[i] == to_locate) { + return i; + } + } + LOG(FATAL) << "Cannot find the item"; + return -1; +} + +/*! \brief Replace a sub-string to another sub-string in a string */ +inline void StrReplace(std::string* base, const std::string& from, const std::string& to) { + auto pos = base->find(from); + while (pos != std::string::npos) { + base->replace(pos, from.size(), to); + pos = base->find(from, pos + to.size()); + } +} + +/********** Utilities for TVM Containers / ByteArray **********/ +/*! \brief Compute mean of a FloatImm array */ +inline double FloatArrayMean(const Array& float_array) { + double sum = 0; + if (float_array.empty()) { + return 0.0; + } + + for (const auto& x : float_array) { + auto floatimm = x.as(); + CHECK(floatimm != nullptr); + sum += floatimm->value; + } + return sum / float_array.size(); +} + +/********** Other Utilities **********/ +/*! \brief Get an int value from an Expr */ +inline int64_t GetIntImm(const PrimExpr& expr) { + auto pint = expr.as(); + CHECK(pint != nullptr); + return pint->value; +} + +/*! \brief Compute the product of the lengths of axes */ +inline int64_t AxisLengthProd(const Array& axes) { + int64_t ret = 1.0; + for (const auto& x : axes) { + if (const IntImmNode* imm = x->dom->extent.as()) { + ret *= imm->value; + } else { + return -1.0; + } + } + return ret; +} + +/*! + * \brief Clean the name of an iterator to make it valid in python code. + * \param str The original name. + * \return The cleaned name. + */ +inline std::string CleanName(const std::string& str) { + std::string ret = str; + StrReplace(&ret, ".", "_"); + StrReplace(&ret, "@", "_"); + StrReplace(&ret, "outer", "o"); + StrReplace(&ret, "inner", "i"); + return ret; +} + +/*! \brief An empty output stream */ +class NullStream : public std::ostream { + public: + NullStream() : std::ostream(nullptr) {} + NullStream(const NullStream&) : std::ostream(nullptr) {} + static NullStream& Global(); +}; + +template +NullStream& operator<<(NullStream& os, const T& value) { + return os; +} + +/*! \brief Get std cout with verbose control */ +inline std::ostream& StdCout(int verbose, int setting = 1) { + return verbose >= setting ? std::cout : NullStream::Global(); +} + +/*! \brief Print multiple chars */ +inline std::string Chars(const char& str, int times) { + std::stringstream ret; + for (int i = 0; i < times; ++i) { + ret << str; + } + return ret.str(); +} + +/*! \brief Print a title */ +inline void PrintTitle(const std::string& title, int verbose) { + StdCout(verbose) << Chars('-', 60) << "\n" + << Chars('-', 25) << " [ " << title << " ]\n" + << Chars('-', 60) << std::endl; +} + +} // namespace auto_schedule +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULE_UTILS_H_ diff --git a/tests/python/unittest/test_auto_schedule_common.py b/tests/python/unittest/test_auto_schedule_common.py new file mode 100644 index 000000000000..691c7e767b6f --- /dev/null +++ b/tests/python/unittest/test_auto_schedule_common.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Common functions for auto_schedule test cases""" + +import threading + +from tvm import te, auto_schedule +import topi + + +@auto_schedule.register_workload +def matmul_auto_schedule_test(N, M, K): + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') + return [A, B, C] + + +@auto_schedule.register_workload("matmul_auto_schedule_test_rename_1") +def matmul_auto_schedule_test_rename_0(N, M, K): + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') + return [A, B, C] + + +def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): + data = te.placeholder((N, CI, H, W), name='Data') + kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel') + bias = te.placeholder((CO, 1, 1), name='Bias') + bn_scale = te.placeholder((CO, 1, 1), name='Bn_scale') + bn_offset = te.placeholder((CO, 1, 1), name='Bn_offset') + + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0], + name='Bias_add') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0], + name='Bn_mul') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0], + name='Bn_add') + out = topi.nn.relu(conv) + + return [data, kernel, bias, bn_offset, bn_scale, out] + + +def get_tiled_matmul(): + A, B, C = matmul_auto_schedule_test(512, 512, 512) + dag = auto_schedule.ComputeDAG([A, B, C]) + + s0 = dag.get_init_state() + its0 = s0.split(C, s0[C].iters[0], [4, 8, 8]) + its1 = s0.split(C, s0[C].iters[4], [8, 4, 4]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3], + s0[C].iters[8]]) + + return dag, s0 + + +class PropagatingThread(threading.Thread): + def run(self): + self.exc = None + try: + self.ret = self._target(*self._args, **self._kwargs) + except BaseException as e: + self.exc = e + + def join(self): + super(PropagatingThread, self).join() + if self.exc: + raise self.exc + return self.ret diff --git a/tests/python/unittest/test_auto_schedule_compute_dag.py b/tests/python/unittest/test_auto_schedule_compute_dag.py new file mode 100644 index 000000000000..8a4f836765eb --- /dev/null +++ b/tests/python/unittest/test_auto_schedule_compute_dag.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test ComputeDAG (replay, infer bound)""" + +import tvm +from tvm import auto_schedule, te + +from test_auto_schedule_common import get_tiled_matmul + + +def test_apply_steps(): + dag, s = get_tiled_matmul() + dag.print_python_code_from_state(s) + sch, tensors = dag.apply_steps_from_state(s) + stmt = tvm.lower(sch, tensors, simple_mode=True) + + +def test_infer_bound(): + dag, s = get_tiled_matmul() + s = dag.infer_bound_from_state(s) + + +def test_estimate_flop(): + dag, s = get_tiled_matmul() + assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5 + + +if __name__ == "__main__": + test_apply_steps() + test_infer_bound() + test_estimate_flop() diff --git a/tests/python/unittest/test_auto_schedule_loop_state.py b/tests/python/unittest/test_auto_schedule_loop_state.py new file mode 100644 index 000000000000..ed54da513d16 --- /dev/null +++ b/tests/python/unittest/test_auto_schedule_loop_state.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test loop state and schedule primitives""" + +import numpy as np + +import tvm +from tvm import auto_schedule, te +import topi + +from test_auto_schedule_common import matmul_auto_schedule_test, conv2d_nchw_bn_relu + + +def test_split_fuse_reorder(): + A, B, C = matmul_auto_schedule_test(512, 512, 512) + dag = auto_schedule.ComputeDAG([A, B, C]) + s0 = dag.get_init_state() + i, j, k = s0[C].iters + + assert i.range.extent == 512 + + io, ii = s0.split(C, i, [16]) + assert s0[C].iters[0] == io + assert s0[C].iters[1] == ii + assert io.range.extent == 32 + assert ii.range.extent == 16 + + jo, ji = s0.split(C, j, [8]) + assert jo.range.extent == 64 + assert ji.range.extent == 8 + + s0.reorder(C, [io, jo, k, ji, ii]) + assert s0[C].iters[2].range.extent == 512 + + fused_it = s0.fuse(C, [io, jo]) + assert fused_it.range.extent == 2048 + + s1 = dag.get_init_state() + i, j, _ = s1[C].iters + i1, i2, i3 = s1.split(C, i, [8, 2]) + j1, j2, j3 = s1.split(C, j, [32, 8], False) + assert s1[C].iters[0].range.extent == 32 + assert s1[C].iters[1].range.extent == 8 + assert s1[C].iters[2].range.extent == 2 + assert s1[C].iters[3].range.extent == 32 + assert s1[C].iters[4].range.extent == 8 + assert s1[C].iters[5].range.extent == 2 + +if __name__ == "__main__": + test_split_fuse_reorder() diff --git a/tests/python/unittest/test_auto_schedule_measure.py b/tests/python/unittest/test_auto_schedule_measure.py new file mode 100644 index 000000000000..52d016de0756 --- /dev/null +++ b/tests/python/unittest/test_auto_schedule_measure.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Test measurement and log serialization. """ + +import tvm +from tvm import auto_schedule +import tempfile + +from test_auto_schedule_common import get_tiled_matmul + + +def test_record(): + dag, s = get_tiled_matmul() + + if not tvm.runtime.enabled("llvm"): + return + target = tvm.target.create("llvm") + task = auto_schedule.SearchTask(dag, "test", target) + + inp = auto_schedule.measure.MeasureInput(task, s) + res = auto_schedule.measure.MeasureResult([0.1], 0, "", 0.2, 1) + + with tempfile.NamedTemporaryFile() as fp: + auto_schedule.save_records(fp.name, [inp], [res]) + + log_reader = auto_schedule.RecordReader(fp.name) + inputs, results = log_reader.read_lines() + assert len(inputs) == 1 + + s1 = dag.infer_bound_from_state(s) + s2 = dag.infer_bound_from_state(inputs[0].state) + + assert s1 == s2 + assert not (s1 == dag.get_init_state()) + + +def test_measure_local_builder_runner(): + dag, s0 = get_tiled_matmul() + + if not tvm.runtime.enabled("llvm"): + return + tgt = tvm.target.create("llvm") + task = auto_schedule.SearchTask(dag, "test", tgt) + + minp = auto_schedule.MeasureInput(task, s0) + local_builder = auto_schedule.LocalBuilder() + local_runner = auto_schedule.LocalRunner() + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + +if __name__ == "__main__": + test_record() + test_measure_local_builder_runner() diff --git a/tests/python/unittest/test_auto_schedule_search_policy.py b/tests/python/unittest/test_auto_schedule_search_policy.py new file mode 100644 index 000000000000..9e08218dcbce --- /dev/null +++ b/tests/python/unittest/test_auto_schedule_search_policy.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test search policy""" + +import random +import numpy as np +import tempfile + +import tvm +from tvm import auto_schedule + +from test_auto_schedule_common import matmul_auto_schedule_test, PropagatingThread + +def search_common(workload=matmul_auto_schedule_test, target="llvm", search_policy = auto_schedule.EmptyPolicy(), + seed=random.randint(1, 1 << 30), runner='local', cost_model=None, + num_measure_trials=2, params=None, pre_search_callbacks=None): + print("Test %s schedule search with the default search policy" % (target)) + + random.seed(seed) + N = 128 + workload_key = auto_schedule.make_workload_key(workload, (N, N, N)) + dag = auto_schedule.ComputeDAG(workload_key) + target = tvm.target.create(target) + task = auto_schedule.SearchTask(dag, workload_key, target) + + with tempfile.NamedTemporaryFile() as fp: + log_file = fp.name + + tuning_options = auto_schedule.TuningOptions(num_measure_trials=num_measure_trials, runner=runner, + verbose=0, + measure_callbacks=[auto_schedule.RecordToFile(log_file)], + pre_search_callbacks=pre_search_callbacks) + sch, args = auto_schedule.auto_schedule(task, search_policy, tuning_options) + inp, res = auto_schedule.load_best(log_file, workload_key, target) + + print("==== Python Code ====") + print(dag.print_python_code_from_state(inp.state)) + + try: + print("==== Lowered Stmt ====") + print(tvm.lower(sch, args, simple_mode=True)) + mod = tvm.build(sch, args, target) + + ctx = tvm.context(str(target), 0) + dtype = dag.tensors[0].dtype + a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx) + c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), np.dot( + a.asnumpy(), b.asnumpy()), rtol=1e-5) + print("==== Verification passed ====") + except Exception: + raise Exception("Error encountered with seed: %d" % (seed)) + print() + + +def test_workload_registry_search_basic(): + if not tvm.runtime.enabled("llvm"): + return + # wrap the search in a new thread to avoid the conflict + # between python's multiprocessing and tvm's thread pool + t = PropagatingThread(target=search_common, kwargs={'seed': 944563397}) + t.start() + t.join() + t = PropagatingThread(target=search_common, + kwargs={'seed': 944563397, 'workload': "matmul_auto_schedule_test"}) + t.start() + t.join() + t = PropagatingThread(target=search_common, + kwargs={'seed': 944563397, 'workload': "matmul_auto_schedule_test_rename_1"}) + t.start() + t.join() + +if __name__ == "__main__": + test_workload_registry_search_basic()