Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] Support early_stopping per task #7377

Merged
merged 7 commits into from
Feb 5, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions python/tvm/auto_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ def __init__(
# task_cts[i] saves how many times task i is tuned
self.task_cts = [0 for _ in range(len(self.tasks))]

# task_best_cts[i] saves the tuning round in which task i found its best latency
comaniac marked this conversation as resolved.
Show resolved Hide resolved
self.task_best_cts = [0 for _ in range(len(self.tasks))]

# task_costs_history[i] saves the latency history of task i
self.task_costs_history = [[] for _ in range(len(self.tasks))]

Expand Down Expand Up @@ -281,13 +284,14 @@ def tune(
search_policy="default",
search_policy_params=None,
adapative_training=False,
per_task_early_stopping=None,
):
"""Tune a batch of tasks together.

Parameters
----------
tune_option: TuningOptions
The options of tuning
The tuning options applied to all tasks.
search_policy: : Union[str, List[SearchPolicy]] = "default"
The list of search policies.
If it is str,
Expand All @@ -299,10 +303,17 @@ def tune(
adapative_training : bool = False
Option used by XGBModel to reduce the model training frequency when there're
too many logs.
per_task_early_stopping : Optional[int]
Stop tuning a task early if getting no improvement after n measurements.
"""
# init members
self.tune_option = tune_option
early_stopping = 1e20 if tune_option.early_stopping < 0 else tune_option.early_stopping
self.early_stopping_all = (
1e20 if tune_option.early_stopping < 0 else tune_option.early_stopping
)
self.early_stopping_task = (
1e20 if per_task_early_stopping is None else per_task_early_stopping
)

self.measurer = ProgramMeasurer(
tune_option.builder,
Expand Down Expand Up @@ -417,13 +428,13 @@ def tune(
if self.cur_score < self.best_score:
self.best_score = self.cur_score
self.best_ct = self.ct
elif self.ct - self.best_ct >= early_stopping and all(
elif self.ct - self.best_ct >= self.early_stopping_all and all(
cost < 1e9 for cost in self.best_costs
):
if self.tune_option.verbose >= 1:
print(
"Stop early since no performance improvement in the last "
+ str(early_stopping)
+ str(self.early_stopping_all)
+ " measurement trials."
)
break
Expand All @@ -439,15 +450,22 @@ def _tune_task(self, task_idx):
self.num_measures_per_round, self.measurer
)

self.task_cts[task_idx] += 1

for res in measure_results:
cost = array_mean(res.costs)
if cost < self.best_costs[task_idx]:
self.task_best_cts[task_idx] = self.task_cts[task_idx]
self.best_costs[task_idx] = cost

if len(measure_inputs) == 0:
# Stop tuning this task in the rest of the process if its search space has been
# fully explored or it has no improvement for a long while.
no_change_trials = (
self.task_cts[task_idx] - self.task_best_cts[task_idx]
) * self.num_measures_per_round
if len(measure_inputs) == 0 or no_change_trials > self.early_stopping_task:
self.dead_tasks.add(task_idx)

self.task_cts[task_idx] += 1
self.task_costs_history[task_idx].append(self.best_costs[task_idx])

self.ct += len(measure_inputs)
Expand Down Expand Up @@ -494,17 +512,24 @@ def _restore_status(self, log_file, num_measures_per_round):
if task_idx is None:
continue

self.task_cts[task_idx] += 1

if res.error_no == 0:
self.best_costs[task_idx] = min(self.best_costs[task_idx], array_mean(res.costs))
cost = array_mean(res.costs)
if self.best_costs[task_idx] < cost:
self.best_costs[task_idx] = cost
self.task_best_cts = self.task_cts[task_idx]

self.task_cts[task_idx] += 1
for idx in range(len(self.tasks)):
if self.task_cts[idx] - self.task_best_cts[idx] > self.early_stopping_task:
self.dead_tasks.add(idx)

for i in range(len(self.tasks)):
# The computation of task_cts is just an estimation.
# The estimation may not be accurate if the log file is changed externally or
# `num_measures_per_round` is different from the last tuning.
self.task_cts[i] = int(self.task_cts[i] / num_measures_per_round + 0.5)
self.task_costs_history[i].append(self.best_costs[i])
self.task_cts[idx] = int(self.task_cts[idx] / num_measures_per_round + 0.5)
self.task_best_cts[idx] = int(self.task_best_cts[idx] / num_measures_per_round + 0.5)
self.task_costs_history[idx].append(self.best_costs[idx])

self.cur_score = self._compute_score(self.best_costs)

Expand Down