diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index 420b5f765a97..b6b05298aef7 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -246,6 +246,9 @@ def __init__(
         # task_cts[i] saves how many times task i is tuned
         self.task_cts = [0 for _ in range(len(self.tasks))]
 
+        # task_best_cts[i] saves the round task i found the best latency
+        self.task_best_cts = [0 for _ in range(len(self.tasks))]
+
         # task_costs_history[i] saves the latency history of task i
         self.task_costs_history = [[] for _ in range(len(self.tasks))]
 
@@ -281,13 +284,14 @@ def tune(
         search_policy="default",
         search_policy_params=None,
         adapative_training=False,
+        per_task_early_stopping=None,
     ):
         """Tune a batch of tasks together.
 
         Parameters
         ----------
         tune_option: TuningOptions
-            The options of tuning
+            The tuning options applied to all tasks.
         search_policy: : Union[str, List[SearchPolicy]] = "default"
             The list of search policies.
             If it is str,
@@ -299,10 +303,17 @@ def tune(
         adapative_training : bool = False
             Option used by XGBModel to reduce the model training frequency when there're
             too many logs.
+        per_task_early_stopping : Optional[int]
+            Stop tuning a task early if getting no improvement after n measurements.
         """
         # init members
         self.tune_option = tune_option
-        early_stopping = 1e20 if tune_option.early_stopping < 0 else tune_option.early_stopping
+        self.early_stopping_all = (
+            1e20 if tune_option.early_stopping < 0 else tune_option.early_stopping
+        )
+        self.early_stopping_task = (
+            1e20 if per_task_early_stopping is None else per_task_early_stopping
+        )
 
         self.measurer = ProgramMeasurer(
             tune_option.builder,
@@ -417,13 +428,13 @@ def tune(
             if self.cur_score < self.best_score:
                 self.best_score = self.cur_score
                 self.best_ct = self.ct
-            elif self.ct - self.best_ct >= early_stopping and all(
+            elif self.ct - self.best_ct >= self.early_stopping_all and all(
                 cost < 1e9 for cost in self.best_costs
             ):
                 if self.tune_option.verbose >= 1:
                     print(
                         "Stop early since no performance improvement in the last "
-                        + str(early_stopping)
+                        + str(self.early_stopping_all)
                         + " measurement trials."
                     )
                     break
@@ -439,15 +450,22 @@ def _tune_task(self, task_idx):
             self.num_measures_per_round, self.measurer
         )
 
+        self.task_cts[task_idx] += 1
+
         for res in measure_results:
             cost = array_mean(res.costs)
             if cost < self.best_costs[task_idx]:
+                self.task_best_cts[task_idx] = self.task_cts[task_idx]
                 self.best_costs[task_idx] = cost
 
-        if len(measure_inputs) == 0:
+        # Stop tuning this task in the rest of the process if its search space has been
+        # fully explored or it has no improvement for a long while.
+        no_change_trials = (
+            self.task_cts[task_idx] - self.task_best_cts[task_idx]
+        ) * self.num_measures_per_round
+        if len(measure_inputs) == 0 or no_change_trials > self.early_stopping_task:
             self.dead_tasks.add(task_idx)
 
-        self.task_cts[task_idx] += 1
         self.task_costs_history[task_idx].append(self.best_costs[task_idx])
 
         self.ct += len(measure_inputs)
@@ -494,16 +512,23 @@ def _restore_status(self, log_file, num_measures_per_round):
             if task_idx is None:
                 continue
 
+            self.task_cts[task_idx] += 1
+
             if res.error_no == 0:
-                self.best_costs[task_idx] = min(self.best_costs[task_idx], array_mean(res.costs))
+                cost = array_mean(res.costs)
+                if cost < self.best_costs[task_idx]:
+                    self.task_best_cts[task_idx] = self.task_cts[task_idx]
+                    self.best_costs[task_idx] = cost
 
-                self.task_cts[task_idx] += 1
+        for idx in range(len(self.tasks)):
+            if self.task_cts[idx] - self.task_best_cts[idx] > self.early_stopping_task:
+                self.dead_tasks.add(idx)
 
-        for i in range(len(self.tasks)):
             # The computation of taks_cts is just an estimation.
             # The estimation may not be accurate if the log file is changed externally or
             # `num_measures_per_round` is different from the last tuning.
-            self.task_cts[i] = int(self.task_cts[i] / num_measures_per_round + 0.5)
-            self.task_costs_history[i].append(self.best_costs[i])
+            self.task_cts[idx] = int(self.task_cts[idx] / num_measures_per_round + 0.5)
+            self.task_best_cts[idx] = int(self.task_best_cts[idx] / num_measures_per_round + 0.5)
+            self.task_costs_history[idx].append(self.best_costs[idx])
 
         self.cur_score = self._compute_score(self.best_costs)