diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index b0b98d8d0f56..f2d7536bea88 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -30,6 +30,7 @@ from tvm.tir.expr import FloatImm from .measure_record import load_records +from .utils import calc_workload_dis_factor, decode_workload_key logger = logging.getLogger("auto_scheduler") @@ -126,18 +127,53 @@ class ApplyHistoryBest(DispatchContext): If is str, then it should be the filename of a records log file. Each row of this file is an encoded record pair. Otherwise, it is an iterator. n_lines: Optional[int] - if it is not None, only load the first `n_lines` lines of log + if it is not None, only load the first `n_lines` lines of log. + include_compatible: bool + When set to True, compatible records will also be considered. """ - def __init__(self, records, n_lines=None): + def __init__(self, records, n_lines=None, include_compatible=False): super(ApplyHistoryBest, self).__init__() + self.include_compatible = include_compatible + # Dict[str (target key), + # Dict[str (workload hash), + # Dict[tuple (workload args), tuple (State, cost)]]] self.best_by_targetkey = {} self.best_by_model = {} self._best_user_defined = {} self.load(records, n_lines) + @staticmethod + def get_workload_entry(best_records, target_key, workload_key): + """Get the entry of the target key and workload key hash in the given best record map. + + Parameters + ---------- + best_records: Dict[str, Dict[str, Dict[str, Any]]] + The best record map. + target_key: str + The first key to the best_records. + workload_key: str + The workload key that can be decoded to workload hash and args. + + Returns + ------- + entry: Dict[str, Any] + The entry in best_records with target key and workload hash. + workload_hash: str + The workload hash decoded from workload_key. + workload_args: Tuple[Any, ...] + The hashable tuple of workload args decoded from workload_key. + """ + workload_hash, workload_args = decode_workload_key(workload_key) + if target_key not in best_records: + best_records[target_key] = {} + if workload_hash not in best_records[target_key]: + best_records[target_key][workload_hash] = {} + return best_records[target_key][workload_hash], workload_hash, workload_args + def load(self, records, n_lines=None): """Load records to this dispatch context @@ -171,29 +207,32 @@ def load(self, records, n_lines=None): if res.error_no != 0: continue + costs = [x.value for x in res.costs if isinstance(x, FloatImm)] + cost = np.mean(costs) + # use target keys in tvm target system as key to build best map for k in inp.task.target.keys: - key = (k, inp.task.workload_key) - if key not in best_by_targetkey: - best_by_targetkey[key] = (inp, res) + entry, _, workload_args = self.get_workload_entry( + best_by_targetkey, k, inp.task.workload_key + ) + if workload_args not in entry: + entry[workload_args] = (inp.state, cost) else: - _, other_res = best_by_targetkey[key] - other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)] - costs = [x.value for x in res.costs if isinstance(x, FloatImm)] - if np.mean(other_costs) > np.mean(costs): - best_by_targetkey[key] = (inp, res) + _, other_cost = entry[workload_args] + if other_cost > cost: + entry[workload_args] = (inp.state, cost) # use model as key to build best map - key = (inp.task.target.model, inp.task.workload_key) - if key not in best_by_model: + entry, _, workload_args = self.get_workload_entry( + best_by_model, inp.task.target.model, inp.task.workload_key + ) + if workload_args not in entry: if inp.task.target.model != "unknown": - best_by_model[key] = (inp, res) + entry[workload_args] = (inp.state, cost) else: - _, other_res = best_by_model[key] - other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)] - costs = [x.value for x in res.costs if isinstance(x, FloatImm)] - if np.mean(other_costs) > np.mean(costs): - best_by_model[key] = (inp, res) + _, other_cost = entry[workload_args] + if other_cost > cost: + entry[workload_args] = (inp.state, cost) logger.debug("Finish loading %d records", counter) @@ -205,31 +244,61 @@ def _query_inside(self, target, workload_key): " above the dispatcher call. So does other target. " ) + def match_record(best_records, target_key, workload_key): + """The helper function to match the record in the given map + and return the matched state, or None if no match. + """ + ret = None + + entry, workload_hash, workload_args = self.get_workload_entry( + best_records, target_key, workload_key + ) + if workload_args in entry: + ret = entry[workload_args][0] + elif self.include_compatible: + best_cost = float("inf") + for args, val in entry.items(): + dis_f = calc_workload_dis_factor( + (workload_hash, workload_args), (workload_hash, args) + ) + if dis_f == float("inf"): + continue + + state, cost = val + cost *= dis_f + if ret is None or cost < best_cost: + best_cost = cost + ret = state + return ret + # first try matching by model - key = (target.model, workload_key) - if key in self._best_user_defined: - return self._best_user_defined[key] - if key in self.best_by_model: - return self.best_by_model[key][0].state + ret = match_record(self._best_user_defined, target.model, workload_key) + if ret is not None: + return ret + ret = match_record(self.best_by_model, target.model, workload_key) + if ret is not None: + return ret # then try matching by target key for k in target.keys: - key = (k, workload_key) - if key in self._best_user_defined: - return self._best_user_defined[key] - if key in self.best_by_targetkey: - return self.best_by_targetkey[key][0].state + ret = match_record(self._best_user_defined, k, workload_key) + if ret is not None: + return ret + ret = match_record(self.best_by_targetkey, k, workload_key) + if ret is not None: + return ret return None def update(self, target, workload_key, state): - model = target.model - key = (model, workload_key) - self._best_user_defined[key] = state + entry, _, workload_args = self.get_workload_entry( + self._best_user_defined, target.model, workload_key + ) + entry[workload_args] = (state, 1) for k in target.keys: - key = (k, workload_key) - self._best_user_defined[key] = state + entry, _, _ = self.get_workload_entry(self._best_user_defined, k, workload_key) + entry[workload_args] = (state, 1) class FallbackContext(DispatchContext): diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py index 9eaef189e081..200d24fa7d50 100644 --- a/python/tvm/auto_scheduler/measure_record.py +++ b/python/tvm/auto_scheduler/measure_record.py @@ -27,7 +27,7 @@ import tvm._ffi from tvm.runtime import Object from .measure import MeasureErrorNo, MeasureCallback -from .utils import decode_workload_key +from .utils import calc_workload_dis_factor, decode_workload_key from . import _ffi_api logger = logging.getLogger("auto_scheduler") @@ -130,65 +130,6 @@ def __iter__(self): yield ret[0], ret[1] # (input, result) -def calc_workload_dis_factor(target_workload_key, workload_key): - """Calculate the distance factor of the workload to the target workload. - If two workloads are not compatible at all (i.e., different compute DAG or function), - then the distance factor is "inf". Otherwise, we calculate the factor by traversing - the workload arguments, which are the arguments of the compute function, - or the output shapes for the ComputeDAG. The factor is calculated by the following rules: - - 1. For non-zero integer values: `product(target_arg / candidate_arg)`. - 2. For non-integer or zero values: "inf" if not equal else 1. - - As a result, factor=1 is the optimal when two workloads are identical. - - Parameters - ---------- - target_workload_key: str - The target workload key in JSON string. - - workload_key: str - The candidate workload key in JSON string. - - Returns - ------- - dis_f: float - The distance factor. - """ - - def flatten_list(inp): - ret = [] - for elt in inp: - if isinstance(elt, list): - ret += flatten_list(elt) - else: - ret.append(elt) - return ret - - target_key, target_args = decode_workload_key(target_workload_key) - target_args = flatten_list(target_args) if target_args is not None else [] - key, args = decode_workload_key(workload_key) - args = flatten_list(args) if args is not None else [] - - # Not even the same func/DAG. - if key != target_key or len(target_args) != len(args): - return float("inf") - - dis_f = 1 - for target_arg, arg in zip(target_args, args): - if isinstance(target_arg, int): - if target_arg == 0 or arg == 0: - if target_arg != arg: - return float("inf") - elif target_arg % arg != 0: - return float("inf") - else: - dis_f *= target_arg / arg - elif target_arg != arg: - return float("inf") - return dis_f - - def load_record_from_string(record): """ Load the measure record from string. @@ -304,7 +245,9 @@ def load_best_record(filename, workload_key=None, target=None, include_compatibl cost = np.mean(costs) if workload_key is not None: - dis_f = calc_workload_dis_factor(workload_key, inp.task.workload_key) + dis_f = calc_workload_dis_factor( + decode_workload_key(workload_key), decode_workload_key(inp.task.workload_key) + ) if dis_f == float("inf"): continue if not include_compatible and dis_f != 1: diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index fd25fdb783f7..8aa33e6775f8 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -57,18 +57,77 @@ def decode_workload_key(workload_key): ------- name: str The workload function name or the DAG hash. - args: Optional[List[Any]] - The arguments of the workload, or None if the workload key format is not decodeable. + args: Optional[Tuple[Any, ...]] + The flatten arguments in a tuple, or None if the workload key format is not decodeable. """ + + def flatten_list(inp): + ret = [] + for elt in inp: + if isinstance(elt, list): + ret += flatten_list(elt) + else: + ret.append(elt) + return ret + try: key_list = json.loads(workload_key) if isinstance(key_list, list) and len(key_list) >= 1: - return key_list[0], key_list[1:] + return key_list[0], tuple(flatten_list(key_list[1:])) except json.decoder.JSONDecodeError: pass return workload_key, None +def calc_workload_dis_factor(target_workload_pair, workload_pair): + """Calculate the distance factor of the workload to the target workload. + If two workloads are not compatible at all (i.e., different compute DAG or function), + then the distance factor is "inf". Otherwise, we calculate the factor by traversing + the workload arguments, which are the arguments of the compute function, + or the output shapes for the ComputeDAG. The factor is calculated by the following rules: + + 1. For non-zero integer values: `product(target_arg / candidate_arg)`. + 2. For non-integer or zero values: "inf" if not equal else 1. + + As a result, factor=1 is the optimal when two workloads are identical. + + Parameters + ---------- + target_workload_pair: Tuple[str, Optional[Tuple[Any, ...]]] + The target workload pair: (hash, argument tuple). + + workload_pair: Tuple[str, Optional[Tuple[Any, ...]]] + The candidate workload pair: (hash, argument tuple). + + Returns + ------- + dis_f: float + The distance factor. + """ + target_key, target_args = target_workload_pair + target_args = target_args if target_args is not None else [] + key, args = workload_pair + args = args if args is not None else [] + + # Not even the same func/DAG. + if key != target_key or len(target_args) != len(args): + return float("inf") + + dis_f = 1 + for target_arg, arg in zip(target_args, args): + if isinstance(target_arg, int): + if target_arg == 0 or arg == 0: + if target_arg != arg: + return float("inf") + elif target_arg % arg != 0: + return float("inf") + else: + dis_f *= target_arg / arg + elif target_arg != arg: + return float("inf") + return dis_f + + def get_func_name(func): """Get name of a function. diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 3b074b273358..041fb7ee76d3 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -202,35 +202,36 @@ def test_recover_measure_input(): def test_workload_dis_factor(): - calc = auto_scheduler.measure_record.calc_workload_dis_factor + calc = auto_scheduler.utils.calc_workload_dis_factor + decode = auto_scheduler.utils.decode_workload_key # Identical target_wkl_key = json.dumps( ["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "float32"] ) - assert calc(target_wkl_key, target_wkl_key) == 1 + assert calc(decode(target_wkl_key), decode(target_wkl_key)) == 1 # Compatible with a factor wkl_key = json.dumps(["func1", [1, 3, 112, 112], [32, 3, 3, 3], [0, 0], [1, 1], "float32"]) - assert calc(target_wkl_key, wkl_key) == 8 * 2 * 2 + assert calc(decode(target_wkl_key), decode(wkl_key)) == 8 * 2 * 2 # Incompatible argument with zeros wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [1, 1], [1, 1], "float32"]) - assert calc(target_wkl_key, wkl_key) == float("inf") + assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf") wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [0, 0], "float32"]) - assert calc(target_wkl_key, wkl_key) == float("inf") + assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf") # Incompatible non-integter argument wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "int8"]) - assert calc(target_wkl_key, wkl_key) == float("inf") + assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf") # Incompatible function wkl_key = json.dumps(["func2", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "float32"]) - assert calc(target_wkl_key, wkl_key) == float("inf") + assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf") # Incompatible due to non-dividable factor wkl_key = json.dumps(["func1", [8, 3, 223, 223], [32, 3, 3, 3], [0, 0], [1, 1], "float32"]) - assert calc(target_wkl_key, wkl_key) == float("inf") + assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf") def test_measure_local_builder_runner(): @@ -322,6 +323,7 @@ def test_measure_target_host(): test_record_follow_split_follow_fused_split() test_record_pragma_storage_align_rfactor() test_recover_measure_input() + test_workload_dis_factor() test_measure_local_builder_runner() test_measure_local_builder_rpc_runner() test_measure_target_host()