Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] Task scheduler callbacks #6945

Merged
merged 6 commits into from
Nov 24, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 120 additions & 32 deletions python/tvm/auto_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor
Programs for Deep Learning." (OSDI 2020).
"""

import os
import time
import math
import logging
Expand Down Expand Up @@ -168,6 +168,9 @@ class TaskScheduler:
The parameter used for 'gradient' strategy
backward_window_size: int = 3
The parameter used for 'gradient' strategy
callbacks: Optional[List[TaskSchedulerCallback]]
The task scheduler callbacks that will be called before and after tuning a task.
If None, then PrintTableInfo callback will be used.
"""

def __init__(
Expand All @@ -182,6 +185,7 @@ def __init__(
beta: float = 2,
gamma: float = 0.5,
backward_window_size: int = 3,
callbacks=None,
):
self.tasks = tasks
if objective_func: # use custom objective function
Expand All @@ -199,6 +203,7 @@ def __init__(
self.beta = beta
self.gamma = gamma
self.backward_window_size = backward_window_size
self.callbacks = callbacks if callbacks is not None else [PrintTableInfo()]

assert len(self.tasks) != 0, "No tasks"
assert self.strategy in ["round-robin", "gradient"]
Expand Down Expand Up @@ -374,39 +379,12 @@ def tune(self, tune_option, search_policy="default"):
)
break

def _print_table_info(self, next_task_idx):
# table header
_ffi_api.PrintTitle("Task Scheduler")
print("| ID | Latency (ms) | Speed (GFLOPS) | Trials |")
print("-------------------------------------------------")

# content
for i in range(len(self.tasks)):
id_str = "%d" % i
latency_str = "%.3f" % (1e3 * self.best_costs[i]) if self.best_costs[i] < 1e9 else "-"
speed_str = (
"%.2f" % (self.tasks[i].compute_dag.flop_ct / self.best_costs[i] / 1e9)
if self.best_costs[i] < 1e9
else "-"
)
trials_str = "%d" % (self.task_cts[i] * self.num_measures_per_round)
print("| %4s | %12s | % 14s | %6s |" % (id_str, latency_str, speed_str, trials_str))
print("-------------------------------------------------")

# overall info
if all(cost < 1e9 for cost in self.best_costs):
total_latency_str = "%.3f" % (self.cur_score * 1e3)
else:
total_latency_str = "-"
print(
"Estimated total latency: %s ms\tTrials: %d\tUsed time : %.0f s\tNext ID: %d\t"
% (total_latency_str, self.ct, time.time() - self.tic, next_task_idx)
)

def _tune_task(self, task_idx):
"""Tune the select task for one round"""
if self.tune_option.verbose >= 1:
self._print_table_info(task_idx)

# Run pre-tune callbacks
for callback in self.callbacks:
callback.pre_tune(self, task_idx)

measure_inputs, measure_results = self.search_policies[task_idx].continue_search_one_round(
self.num_measures_per_round, self.measurer
Expand All @@ -426,6 +404,10 @@ def _tune_task(self, task_idx):
self.ct += len(measure_inputs)
self.cur_score = self._compute_score(self.best_costs)

# Run post-tune callbacks
for callback in self.callbacks:
callback.post_tune(self, task_idx)

def _compute_score(self, costs):
"""compute the objective function"""
return self.objective_func(costs)
Expand Down Expand Up @@ -478,3 +460,109 @@ def _restore_status(self, log_file, num_measures_per_round):
self.cur_score = self._compute_score(self.best_costs)

logger.info("TaskScheduler: Loaded %d measurement records from %s", total_ct + 1, log_file)


class TaskSchedulerCallback:
"""The base class of task scheduler callback functions. """

def pre_tune(self, task_scheduler, task_id):
"""The callback before tuning each task.

Parameters
----------
task_scheduler: TaskScheduler
The task scheduler.
task_id: int
The task ID going to be tuned.
"""
# Do nothing by default

def post_tune(self, task_scheduler, task_id):
"""The callback after tuning each task.

Parameters
----------
task_scheduler: TaskScheduler
The task scheduler.
task_id: int
The task ID be tuned.
"""
# Do nothing by default


class PrintTableInfo(TaskSchedulerCallback):
"""The callback that prints a table of current progress."""

def pre_tune(self, task_scheduler, task_id):
if task_scheduler.tune_option.verbose < 1:
return

_ffi_api.PrintTitle("Task Scheduler")
print("| ID | Latency (ms) | Speed (GFLOPS) | Trials |")
print("-------------------------------------------------")

# content
for i in range(len(task_scheduler.tasks)):
id_str = "%d" % i
latency_str = (
"%.3f" % (1e3 * task_scheduler.best_costs[i])
if task_scheduler.best_costs[i] < 1e9
else "-"
)
speed_str = (
"%.2f"
% (task_scheduler.tasks[i].compute_dag.flop_ct / task_scheduler.best_costs[i] / 1e9)
if task_scheduler.best_costs[i] < 1e9
else "-"
)
trials_str = "%d" % (task_scheduler.task_cts[i] * task_scheduler.num_measures_per_round)
print("| %4s | %12s | % 14s | %6s |" % (id_str, latency_str, speed_str, trials_str))
print("-------------------------------------------------")

# overall info
if all(cost < 1e9 for cost in task_scheduler.best_costs):
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
else:
total_latency_str = "-"
print(
"Estimated total latency: %s ms\tTrials: %d\tUsed time : %.0f s\tNext ID: %d\t"
% (
total_latency_str,
task_scheduler.ct,
time.time() - task_scheduler.tic,
task_id,
)
)


class LogEstimatedLatency(TaskSchedulerCallback):
"""Log the estimated latency to the file after tuning a task.

Parameters
----------
log_file: str
The log file path.
"""

def __init__(self, log_file):
if os.path.exists(log_file): # Remove existing log
os.remove(log_file)

self.log_file = log_file

def post_tune(self, task_scheduler, task_id):
if all(cost < 1e9 for cost in task_scheduler.best_costs):
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
else:
total_latency_str = "N/A"

with open(self.log_file, "a") as filep:
filep.write(
"ElapsedTime(s)\t%.0f\tEstimatedLatency(ms)\t%s\tTrials\t%d\n"
% (
time.time() - task_scheduler.tic,
total_latency_str,
task_scheduler.ct,
)
)
filep.flush()