diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index 6e73ab13b1a5..fca889448180 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -43,40 +43,74 @@
 @tvm._ffi.register_object("auto_scheduler.HardwareParams")
 class HardwareParams(Object):
-    """The parameters of target hardware used to guide the search policy
+    """The parameters of target hardware used to guide the search policy.
+
+    When a parameter isn't provided, it will instead use the
+    current machine's default value if target is specified.
 
     TODO(jcf94): This is considered to be merged with the new Target specification:
     https://discuss.tvm.apache.org/t/rfc-tvm-target-specification/6844
 
     Parameters
     ----------
-    num_cores : int
+    num_cores : int, optional
        The number of device cores.
-    vector_unit_bytes : int
+    vector_unit_bytes : int, optional
        The width of vector units in bytes.
-    cache_line_bytes : int
+    cache_line_bytes : int, optional
        The size of cache line in bytes.
-    max_shared_memory_per_block : int
+    max_shared_memory_per_block : int, optional
        The max shared memory per block in bytes.
-    max_local_memory_per_block : int
+    max_local_memory_per_block : int, optional
        The max local memory per block in bytes.
-    max_threads_per_block : int
+    max_threads_per_block : int, optional
        The max number of threads per block.
-    max_vthread_extent : int
+    max_vthread_extent : int, optional
        The max vthread extent.
-    warp_size : int
+    warp_size : int, optional
        The thread numbers of a warp.
+    target : str or Target, optional
+        The compilation target. Used to determine default values if provided.
+    target_host : str or Target, optional
+        The compilation target host. Used to determine default values if provided.
    """

    def __init__(
        self,
-        num_cores,
-        vector_unit_bytes,
-        cache_line_bytes,
-        max_shared_memory_per_block,
-        max_local_memory_per_block,
-        max_threads_per_block,
-        max_vthread_extent,
-        warp_size,
+        num_cores=None,
+        vector_unit_bytes=None,
+        cache_line_bytes=None,
+        max_shared_memory_per_block=None,
+        max_local_memory_per_block=None,
+        max_threads_per_block=None,
+        max_vthread_extent=None,
+        warp_size=None,
+        target=None,
+        target_host=None,
    ):
+        # If target is provided, get the default parameters for this machine.
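+        # For example, HardwareParams(target="llvm") leaves every field unset so
+        # that all of them are filled in from this machine's defaults, while a call
+        # such as HardwareParams(num_cores=8, target="llvm") keeps the explicit core
+        # count and only fills in the rest ("llvm" and 8 are illustrative values).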
+        if target is not None:
+            if isinstance(target, str):
+                target = tvm.target.Target(target)
+            if isinstance(target_host, str):
+                target_host = tvm.target.Target(target_host)
+            default_params = _ffi_api.GetDefaultHardwareParams(target, target_host)
+
+            if num_cores is None:
+                num_cores = default_params.num_cores
+            if vector_unit_bytes is None:
+                vector_unit_bytes = default_params.vector_unit_bytes
+            if cache_line_bytes is None:
+                cache_line_bytes = default_params.cache_line_bytes
+            if max_shared_memory_per_block is None:
+                max_shared_memory_per_block = default_params.max_shared_memory_per_block
+            if max_local_memory_per_block is None:
+                max_local_memory_per_block = default_params.max_local_memory_per_block
+            if max_threads_per_block is None:
+                max_threads_per_block = default_params.max_threads_per_block
+            if max_vthread_extent is None:
+                max_vthread_extent = default_params.max_vthread_extent
+            if warp_size is None:
+                warp_size = default_params.warp_size
+
         self.__init_handle_by_constructor__(
             _ffi_api.HardwareParams,
             num_cores,
@@ -89,6 +123,21 @@ def __init__(
             warp_size,
         )
 
+    def __str__(self):
+        """Pretty printing for hardware parameter configuration."""
+        format_str = (
+            "HardwareParams:\n"
+            f"  num_cores: {self.num_cores}\n"
+            f"  vector_unit_bytes: {self.vector_unit_bytes}\n"
+            f"  cache_line_bytes: {self.cache_line_bytes}\n"
+            f"  max_shared_memory_per_block: {self.max_shared_memory_per_block}\n"
+            f"  max_local_memory_per_block: {self.max_local_memory_per_block}\n"
+            f"  max_threads_per_block: {self.max_threads_per_block}\n"
+            f"  max_vthread_extent: {self.max_vthread_extent}\n"
+            f"  warp_size: {self.warp_size}\n"
+        )
+        return format_str
+
 
 @tvm._ffi.register_object("auto_scheduler.TuningOptions")
 class TuningOptions(Object):
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index 0221870badcf..5cae556e2747 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -329,7 +329,10 @@ def tune(
             tune_option.num_measures_per_round, tune_option.num_measure_trials // len(self.tasks)
         )
         if self.num_measures_per_round <= 0:
-            raise ValueError("num_measure_trials is too small. Please set it to a higher value.")
+            raise ValueError(
+                "num_measure_trials is too small. Please set it to a higher value. "
+                f"It should be at least {len(self.tasks)} for this model."
+            )
 
         # restore the status of the task scheduler from a log file
         if self.load_log_file:
diff --git a/python/tvm/driver/tvmc/__init__.py b/python/tvm/driver/tvmc/__init__.py
index d9c15792349a..42184c34df74 100644
--- a/python/tvm/driver/tvmc/__init__.py
+++ b/python/tvm/driver/tvmc/__init__.py
@@ -22,6 +22,9 @@
 from . import autotuner
 from . import compiler
 from . import runner
+from . import result_utils
 from .frontends import load_model as load
 from .compiler import compile_model as compile
 from .runner import run_module as run
+from .autotuner import tune_model as tune
+from .model import TVMCModel, TVMCPackage, TVMCResult
diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index bdb4c6200b98..8f94c53045e5 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -20,10 +20,14 @@
 import os.path
 import logging
 import time
+from copy import deepcopy
+from typing import Optional, Dict, List, Union
 from urllib.parse import urlparse
 
+import tvm
 from tvm import autotvm, auto_scheduler
+from tvm.auto_scheduler.search_task import HardwareParams
 from tvm.autotvm.tuner import GATuner
 from tvm.autotvm.tuner import GridSearchTuner
 from tvm.autotvm.tuner import RandomTuner
@@ -33,6 +37,7 @@
 from . import common, composite_target, frontends
 from .common import TVMCException
 from .main import register_parser
+from .model import TVMCModel
 
 
 # pylint: disable=invalid-name
@@ -52,7 +57,7 @@ def add_tune_parser(subparsers):
     )
 
     # There is some extra processing required to define the actual default value
-    # for --min-repeat-ms. This is done in `drive_tune`.
+    # for --min-repeat-ms. This is done in `tune_model`.
     parser.add_argument(
         "--min-repeat-ms",
         default=None,
@@ -93,7 +98,8 @@ def add_tune_parser(subparsers):
     )
     parser.add_argument(
         "--rpc-key",
-        help="the RPC tracker key of the target device. Required when --rpc-tracker is provided.",
+        help="the RPC tracker key of the target device. "
+        "Required when --rpc-tracker is provided.",
     )
     parser.add_argument(
         "--rpc-tracker",
@@ -142,50 +148,50 @@ def add_tune_parser(subparsers):
     auto_scheduler_group.add_argument(
         "--cache-line-bytes",
         type=int,
-        default=64,
-        help="the size of cache line in bytes",
+        help="the size of cache line in bytes. "
+        "If not specified, it will be autoset for the current machine.",
     )
     auto_scheduler_group.add_argument(
         "--num-cores",
         type=int,
-        default=4,
-        help="the number of device cores",
+        help="the number of device cores. "
+        "If not specified, it will be autoset for the current machine.",
     )
     auto_scheduler_group.add_argument(
         "--vector-unit-bytes",
         type=int,
-        default=16,
-        help="the width of vector units in bytes",
+        help="the width of vector units in bytes. "
+        "If not specified, it will be autoset for the current machine.",
     )
     auto_scheduler_group.add_argument(
         "--max-shared-memory-per-block",
         type=int,
-        default=0,
-        help="the max shared memory per block in bytes",
+        help="the max shared memory per block in bytes. "
+        "If not specified, it will be autoset for the current machine.",
     )
     auto_scheduler_group.add_argument(
         "--max-local-memory-per-block",
         type=int,
-        default=0,
-        help="the max local memory per block in bytes",
+        help="the max local memory per block in bytes. "
+        "If not specified, it will be autoset for the current machine.",
     )
     auto_scheduler_group.add_argument(
         "--max-threads-per-block",
         type=int,
-        default=0,
-        help="the max number of threads per block",
+        help="the max number of threads per block. "
+        "If not specified, it will be autoset for the current machine.",
     )
     auto_scheduler_group.add_argument(
         "--max-vthread-extent",
         type=int,
-        default=0,
-        help="the max vthread extent",
+        help="the max vthread extent. "
+        "If not specified, it will be autoset for the current machine.",
     )
     auto_scheduler_group.add_argument(
         "--warp-size",
         type=int,
-        default=0,
-        help="the thread numbers of a warp",
+        help="the thread numbers of a warp. "
+        "If not specified, it will be autoset for the current machine.",
     )
     auto_scheduler_group.add_argument(
         "--include-simple-tasks",
@@ -216,7 +222,6 @@ def add_tune_parser(subparsers):
         help="specify non-generic shapes for model to run, format is "
         '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"',
         type=common.parse_shape_string,
-        default=None,
     )
 
 
@@ -228,8 +233,22 @@ def drive_tune(args):
     args: argparse.Namespace
         Arguments from command line parser.
     """
-    # extra arguments validation before importing the model, so that obvious errors
-    # are pointed in advance.
+    tvmc_model = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)
+
+    # Specify hardware parameters, although they'll only be used if autoscheduling.
+    hardware_params = auto_scheduler.HardwareParams(
+        num_cores=args.num_cores,
+        vector_unit_bytes=args.vector_unit_bytes,
+        cache_line_bytes=args.cache_line_bytes,
+        max_shared_memory_per_block=args.max_shared_memory_per_block,
+        max_local_memory_per_block=args.max_local_memory_per_block,
+        max_threads_per_block=args.max_threads_per_block,
+        max_vthread_extent=args.max_vthread_extent,
+        warp_size=args.warp_size,
+        target=args.target,
+        target_host=args.target_host,
+    )
+
     if args.rpc_tracker:
         parsed_url = urlparse("//%s" % args.rpc_tracker)
         rpc_hostname = parsed_url.hostname
@@ -241,11 +260,127 @@ def drive_tune(args):
             raise common.TVMCException(
                 "need to provide an RPC tracker key (--rpc-key) for remote tuning"
             )
+    else:
+        rpc_hostname = None
+        rpc_port = None
+
+    tune_model(
+        tvmc_model,
+        args.target,
+        tuning_records=args.output,
+        prior_records=args.tuning_records,
+        enable_autoscheduler=args.enable_autoscheduler,
+        rpc_key=args.rpc_key,
+        hostname=rpc_hostname,
+        port=rpc_port,
+        trials=args.trials,
+        target_host=args.target_host,
+        tuner=args.tuner,
+        min_repeat_ms=args.min_repeat_ms,
+        early_stopping=args.early_stopping,
+        desired_layout=args.desired_layout,
+        timeout=args.timeout,
+        repeat=args.repeat,
+        number=args.number,
+        parallel=args.parallel,
+        hardware_params=hardware_params,
+        include_simple_tasks=args.include_simple_tasks,
+        log_estimated_latency=args.log_estimated_latency,
+    )
+
+
+def tune_model(
+    tvmc_model: TVMCModel,
+    target: str,
+    tuning_records: Optional[str] = None,
+    prior_records: Optional[str] = None,
+    enable_autoscheduler: bool = False,
+    rpc_key: Optional[str] = None,
+    hostname: Optional[str] = None,
+    port: Optional[Union[int, str]] = 9090,
+    trials: int = 10000,
+    target_host: Optional[str] = None,
+    tuner: str = "xgb",
+    min_repeat_ms: Optional[int] = None,
+    early_stopping: Optional[int] = None,
+    desired_layout: Optional[str] = None,
+    timeout: int = 10,
+    repeat: int = 1,
+    number: int = 10,
+    parallel: int = 4,
+    hardware_params: Optional[HardwareParams] = None,
+    include_simple_tasks: bool = False,
+    log_estimated_latency: bool = False,
+):
+    """Use tuning to automatically optimize the functions in a model.
+
+    Parameters
+    ----------
+    tvmc_model : TVMCModel
+        The model to be optimized.
+    target : str
+        Compilation target as plain string, inline JSON or path to a JSON file.
+    tuning_records: str, optional
+        The path to a file that tuning results will be saved to. If not specified,
+        a temporary file will be used.
+    prior_records: str, optional
+        A path to previous tuning results that will be used to hot-start the tuning
+        cost model if provided.
+    enable_autoscheduler : bool, optional
+        When true, use autoscheduling rather than autotvm. This should produce
+        faster kernels for compatible model-target pairs.
+    rpc_key : str, optional
+        The RPC tracker key of the target device. Required when hostname is provided.
+    hostname : str, optional
+        The IP address of an RPC tracker, used when benchmarking remotely.
+    port : int or str, optional
+        The port of the RPC tracker to connect to. Defaults to 9090.
+    trials : int, optional
+        The number of schedules to try out for the entire model. Note that the default
+        value is chosen as a decent average for most models, but larger models may need
+        more trials to reach a good result while smaller models will converge with fewer
+        trials.
+    target_host : str, optional
+        The compilation target for the host code, if different from target.
+    tuner : str, optional
+        The type of tuner to use when tuning with autotvm. Can be one of
+        "ga", "gridsearch", "random", "xgb", "xgb_knob", and "xgb-rank".
+    min_repeat_ms : int, optional
+        Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other targets.
+    early_stopping : int, optional
+        When specified, stop tuning after this number of trials if results aren't improving.
+    desired_layout : str, optional
+        Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph
+        will have their layout set to this format. Tasks will then be tuned using this
+        specified layout.
+    timeout : int, optional
+        If a kernel trial lasts longer than this duration in seconds, it will be
+        considered a failure.
+    repeat : int, optional
+        How many times each measurement should be repeated.
+    number : int, optional
+        The number of runs a single repeat is made of.
+    parallel : int, optional
+        The maximum number of parallel devices to use when tuning.
+    hardware_params : auto_scheduler.HardwareParams, optional
+        When using the autoscheduler, this object defines the configuration of the target
+        hardware.
+    include_simple_tasks : bool, optional
+        Whether to extract simple operations or only computationally intensive ones when
+        using the autoscheduler.
+    log_estimated_latency : bool, optional
+        If using the autoscheduler, write the estimated latency at each step of tuning to file.
 
-    target, extra_targets = common.target_from_cli(args.target)
-    target_host = args.target_host
+    Returns
+    -------
+    tuning_records : str
+        The path to the produced tuning log file.
+    """
+    target, extra_targets = common.target_from_cli(target)
     target, target_host = Target.check_and_update_host_consist(target, target_host)
-    mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)
+    # TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source
+    # model is fixed. For now, creating a clone avoids the issue.
+    mod = deepcopy(tvmc_model.mod)
+    params = tvmc_model.params
+    if tuning_records is None:
+        tuning_records = tvmc_model.default_tuning_records_path()
 
     for codegen_from_cli in extra_targets:
         codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
@@ -255,97 +390,113 @@
     # min_repeat_ms should be:
     # a. the value provided by the user, if any, or
     # b. 0ms in case target is "cpu"; otherwise 1000ms
-    if args.min_repeat_ms is not None:
-        min_repeat_ms = args.min_repeat_ms
-    else:
+    if min_repeat_ms is None:
         min_repeat_ms = 0 if target.keys[0] == "cpu" else 1000
-    logger.debug("Default --min-repeat-ms for this target is %s", min_repeat_ms)
+        logger.info("Default --min-repeat-ms for this target is %s", min_repeat_ms)
 
-    if args.rpc_tracker:
-        runner_ctor = auto_scheduler.RPCRunner if args.enable_autoscheduler else autotvm.RPCRunner
+    if rpc_key:
+        if hostname is None or port is None:
+            raise common.TVMCException(
+                "You must provide a hostname and port to connect to a remote RPC device."
+            )
+        if isinstance(port, str):
+            port = int(port)
+
+        logger.info("Tuning will be performed on device %s at %s:%d.", rpc_key, hostname, port)
+
+        runner_ctor = auto_scheduler.RPCRunner if enable_autoscheduler else autotvm.RPCRunner
         runner = runner_ctor(
-            key=args.rpc_key,
-            host=rpc_hostname,
-            port=rpc_port,
-            number=args.number,
-            repeat=args.repeat,
-            n_parallel=args.parallel,
-            timeout=args.timeout,
+            key=rpc_key,
+            host=hostname,
+            port=port,
+            number=number,
+            repeat=repeat,
+            n_parallel=parallel,
+            timeout=timeout,
             min_repeat_ms=min_repeat_ms,
         )
     else:
-        logger.info("starting localhost tuning")
+        logger.info("Starting localhost tuning.")
         runner_ctor = (
-            auto_scheduler.LocalRunner if args.enable_autoscheduler else autotvm.LocalRunner
+            auto_scheduler.LocalRPCMeasureContext if enable_autoscheduler else autotvm.LocalRunner
        )
-        runner = runner_ctor(
-            number=args.number,
-            repeat=args.repeat,
-            timeout=args.timeout,
+        local_server = runner_ctor(
+            number=number,
+            repeat=repeat,
+            timeout=timeout,
             min_repeat_ms=min_repeat_ms,
         )
 
-    if args.enable_autoscheduler:
-        # Specify hardware parameters
-        hardware_params = auto_scheduler.HardwareParams(
-            args.num_cores,
-            args.vector_unit_bytes,
-            args.cache_line_bytes,
-            args.max_shared_memory_per_block,
-            args.max_local_memory_per_block,
-            args.max_threads_per_block,
-            args.max_vthread_extent,
-            args.warp_size,
-        )
+        # For autoscheduling on some devices, we need to maintain a LocalRPCMeasureContext object.
+        if enable_autoscheduler:
+            runner = local_server.runner
+        else:
+            runner = local_server
+
+    if enable_autoscheduler:
         tasks, weights = autoscheduler_get_tuning_tasks(
             mod=mod,
             params=params,
             target=target,
-            alter_layout=args.desired_layout,
+            alter_layout=desired_layout,
             hardware_params=hardware_params,
-            include_simple_tasks=args.include_simple_tasks,
+            include_simple_tasks=include_simple_tasks,
         )
 
         # Create the autoscheduler tuning options
         tuning_options = auto_scheduler.TuningOptions(
-            num_measure_trials=args.trials,
-            measure_callbacks=[auto_scheduler.RecordToFile(args.output)],
+            num_measure_trials=trials,
+            measure_callbacks=[auto_scheduler.RecordToFile(tuning_records)],
             runner=runner,
-            early_stopping=args.early_stopping,
+            early_stopping=early_stopping,
         )
 
+        logger.info("Autoscheduling with configuration: %s", tuning_options)
+
         # Schedule the tasks (i.e., produce a schedule for each task)
-        schedule_tasks(
-            tasks, weights, tuning_options, args.tuning_records, args.log_estimated_latency
-        )
+        schedule_tasks(tasks, weights, tuning_options, prior_records, log_estimated_latency)
     else:
         tasks = autotvm_get_tuning_tasks(
             mod=mod,
             params=params,
             target=target,
-            alter_layout=args.desired_layout,
+            alter_layout=desired_layout,
        )
 
-        tuning_option = {
-            "tuner": args.tuner,
-            "trials": args.trials,
-            "early_stopping": args.early_stopping,
+        # In autotvm, trials is specified per task. We can convert the per-model input
+        # provided to per-task trials by dividing by the number of tasks.
+        trials = int(trials / len(tasks))
+        logger.info("Autotuning with %d trials per task.", trials)
+
+        tuning_options = {
+            "tuner": tuner,
+            "trials": trials,
+            "early_stopping": early_stopping,
             "measure_option": autotvm.measure_option(
                 builder=autotvm.LocalBuilder(build_func="default"), runner=runner
             ),
-            "tuning_records": args.tuning_records,
+            "tuning_records": prior_records,
         }
-        logger.debug(" tuning options: %s", tuning_option)
+        logger.info("Autotuning with configuration: %s", tuning_options)
 
-        tune_tasks(tasks, args.output, **tuning_option)
+        tune_tasks(tasks, tuning_records, **tuning_options)
+
+    return tuning_records
 
 
-def autotvm_get_tuning_tasks(mod, params, target, target_host=None, alter_layout=None):
+def autotvm_get_tuning_tasks(
+    mod: tvm.IRModule,
+    params: Dict[str, tvm.nd.NDArray],
+    target: str,
+    target_host: Optional[str] = None,
+    alter_layout: Optional[str] = None,
+):
     """Get the autotvm tuning tasks for a given relay module.
 
     Parameters
     ----------
-    mod : tvm.relay.Module
+    mod : tvm.IRModule
        The relay module from which to extract tuning tasks.
     params : dict
        The params for the relay module.
@@ -378,19 +529,19 @@
 
 
 def autoscheduler_get_tuning_tasks(
-    mod,
-    params,
-    target,
-    target_host=None,
-    alter_layout=None,
-    hardware_params=None,
-    include_simple_tasks=False,
+    mod: tvm.IRModule,
+    params: Dict[str, tvm.nd.NDArray],
+    target: str,
+    target_host: Optional[str] = None,
+    alter_layout: Optional[str] = None,
+    hardware_params: Optional[HardwareParams] = None,
+    include_simple_tasks: bool = False,
 ):
     """Get the autoscheduler tuning tasks for a given relay module.
 
     Parameters
     ----------
-    mod : tvm.relay.Module
+    mod : tvm.IRModule
        The relay module from which to extract tuning tasks.
     params : dict
        The params for the relay module.
@@ -430,7 +581,11 @@
 
 
 def schedule_tasks(
-    tasks, task_weights, tuning_options, tuning_records=None, log_estimated_latency=False
+    tasks: List[auto_scheduler.SearchTask],
+    task_weights: List[float],
+    tuning_options: auto_scheduler.TuningOptions,
+    prior_records: Optional[str] = None,
+    log_estimated_latency: bool = False,
 ):
     """Generate the schedules for the different tasks (i.e., subgraphs) contained in the module.
     Store the schedules in a json file that will be used later by the compiler.
 
@@ -441,10 +596,12 @@
     tasks : list
        A list of auto_scheduler.SearchTask to tune.
     task_weights : list
        The weight (i.e. the number of appearance) of extracted tasks
-    tuning_options: dict
+    tuning_options: auto_scheduler.TuningOptions
        The options of tuning
-    tuning_records : str, optional
+    prior_records : str, optional
        The json file used to preload the autoscheduler
+    log_estimated_latency : bool, optional
+        If true, writes the estimated runtime of the model during each step of tuning to file.
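+
+    A minimal usage sketch (the prior log path is hypothetical):
+
+    .. code-block:: python
+
+        tuning_options = auto_scheduler.TuningOptions(num_measure_trials=10000)
+        schedule_tasks(tasks, weights, tuning_options, prior_records="prior_log.json")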
""" if not log_estimated_latency: callbacks = [auto_scheduler.task_scheduler.PrintTableInfo()] @@ -456,7 +613,7 @@ def schedule_tasks( # Create the scheduler tuner = auto_scheduler.TaskScheduler( - tasks, task_weights, load_log_file=tuning_records, callbacks=callbacks + tasks, task_weights, load_log_file=prior_records, callbacks=callbacks ) # Tune the tasks @@ -464,13 +621,13 @@ def schedule_tasks( def tune_tasks( - tasks, - log_file, - measure_option, - tuner, - trials, - early_stopping=None, - tuning_records=None, + tasks: List[autotvm.task.Task], + log_file: str, + measure_option: autotvm.measure_option, + tuner: str, + trials: int, + early_stopping: Optional[int] = None, + tuning_records: Optional[str] = None, ): """Tune a list of tasks and output the history to a log file. diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 77ba1cb47cc8..34f59aac9712 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -59,6 +59,7 @@ def convert_graph_layout(mod, desired_layout): # conv2d as heavily-sensitive operators. desired_layouts = { "nn.conv2d": [desired_layout, "default"], + "nn.conv2d_transpose": [desired_layout, "default"], "qnn.conv2d": [desired_layout, "default"], } @@ -99,8 +100,8 @@ def validate_targets(parse_targets): if len(tvm_targets) > 1: verbose_tvm_targets = ", ".join(tvm_targets) raise TVMCException( - f"Only one of the following targets can be used at a time. " - "Found: {verbose_tvm_targets}." + "Only one of the following targets can be used at a time. " + f"Found: {verbose_tvm_targets}." ) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 6884c30049d1..3f1d04aee7fd 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -19,17 +19,16 @@ """ import logging import os.path -import tarfile +from typing import Optional, Dict, List, Union, Callable from pathlib import Path import tvm from tvm import autotvm, auto_scheduler -from tvm import relay, runtime -from tvm.contrib import cc -from tvm.contrib import utils +from tvm import relay from tvm.target import Target from . 
import common, composite_target, frontends +from .model import TVMCModel, TVMCPackage from .main import register_parser @@ -96,7 +95,7 @@ def add_compile_parser(subparsers): default=None, ) parser.add_argument( - "--disable-pass", + "--disabled-pass", help="disable specific passes, comma-separated list of pass names", type=common.parse_pass_list_str, default="", @@ -117,35 +116,36 @@ def drive_compile(args): Zero if successfully completed """ - mod, params = frontends.load_model(args.FILE, args.model_format, args.input_shapes) + tvmc_model = frontends.load_model(args.FILE, args.model_format, args.input_shapes) - graph, lib, params, dumps = compile_model( - mod, - params, + dump_code = [x.strip() for x in args.dump_code.split(",")] if args.dump_code else None + + compile_model( + tvmc_model, args.target, - args.dump_code, - None, - args.tuning_records, - args.desired_layout, - args.disable_pass, + tuning_records=args.tuning_records, + package_path=args.output, + cross=args.cross_compiler, + dump_code=dump_code, + target_host=None, + desired_layout=args.desired_layout, + disabled_pass=args.disabled_pass, ) - if dumps: - save_dumps(args.output, dumps) - - save_module(args.output, graph, lib, params, args.cross_compiler) return 0 def compile_model( - mod, - params, - target, - dump_code=None, - target_host=None, - tuning_records=None, - alter_layout=None, - disabled_pass=None, + tvmc_model: TVMCModel, + target: str, + tuning_records: Optional[str] = None, + package_path: Optional[str] = None, + cross: Optional[Union[str, Callable]] = None, + export_format: str = "so", + dump_code: Optional[List[str]] = None, + target_host: Optional[str] = None, + desired_layout: Optional[str] = None, + disabled_pass: Optional[str] = None, ): """Compile a model from a supported framework into a TVM module. @@ -155,23 +155,29 @@ def compile_model( Parameters ---------- - mod: IRModule - The relay module to be compiled. - params: dict - A dictionary containing the module's parameters. + tvmc_model : TVMCModel + The model object that should be compiled. target : str The target for which to compile. Can be a plain string or a path. + tuning_records : str + A path to tuning records produced using tvmc.tune. When provided, + compilation will use more optimized kernels leading to better results. + package_path : str, optional + The path to export the compiled model to. If not provided it will + be saved in a temporary directory. + cross : str or callable object, optional + Function that performs the actual compilation + export_format : str + What format to use when saving the function library. Must be one of "so" or "tar". + When compiling for a remote device without a cross compiler, "tar" will likely work better. dump_code : list, optional Dump the generated code for the specified source types, on the requested target. target_host : str, optional The target of the host machine if host-side code needs to be generated. - tuning_records: str, optional - Path to the file produced by the tuning to be used during - compilation. - alter_layout: str, optional + desired_layout: str, optional The layout to convert the graph to. Note, the convert layout pass doesn't currently guarantee the whole of the graph will be converted to the chosen layout. @@ -182,24 +188,18 @@ def compile_model( Returns ------- - graph : str - A JSON-serialized TVM execution graph. - lib : tvm.module.Module - A TVM module containing the compiled functions. - params : dict - The parameters (weights) for the TVM module. 
- dumps : dict - Dictionary containing the dumps specified. + compiled_model : TVMCPackage + The compiled TVMCModel ready to be run. """ - dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None + mod, params = tvmc_model.mod, tvmc_model.params + config = {} - if alter_layout: - mod = common.convert_graph_layout(mod, alter_layout) + if desired_layout: + mod = common.convert_graph_layout(mod, desired_layout) tvm_target, extra_targets = common.target_from_cli(target) - target_host = tvm_target if not target_host else target_host tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host) for codegen_from_cli in extra_targets: @@ -225,21 +225,24 @@ def compile_model( opt_level=3, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph with autoscheduler") - graph_module = relay.build(mod, target=target, params=params) + graph_module = relay.build(mod, target=tvm_target, params=params) else: with autotvm.apply_history_best(tuning_records): with tvm.transform.PassContext( opt_level=3, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph with tuning records") - graph_module = relay.build(mod, tvm_target, params=params) + graph_module = relay.build(mod, target=tvm_target, params=params) else: with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=disabled_pass): logger.debug("building relay graph (no tuning records provided)") - graph_module = relay.build(mod, tvm_target, params=params) + graph_module = relay.build(mod, target=tvm_target, params=params) # Generate output dump files with sources - dump_code = dump_code or [] + if dump_code is None: + dump_code = [] + if not isinstance(dump_code, list): + dump_code = [dump_code] dumps = {} for source_type in dump_code: lib = graph_module.get_lib() @@ -248,59 +251,17 @@ def compile_model( source = str(mod) if source_type == "relay" else lib.get_source(source_type) dumps[source_type] = source - # TODO we need to update this return to use the updated graph module APIs - # as these getter functions will be deprecated in the next release (@leandron) - return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps - + # Create a new tvmc model package object from the graph definition. + package_path = tvmc_model.export_package(graph_module, package_path, cross, export_format) -def save_module(module_path, graph, lib, params, cross=None): - """ - Create a tarball containing the generated TVM graph, - exported library and parameters - - Parameters - ---------- - module_path : str - path to the target tar.gz file to be created, - including the file name - graph : str - A JSON-serialized TVM execution graph. - lib : tvm.module.Module - A TVM module containing the compiled functions. - params : dict - The parameters (weights) for the TVM module. 
- cross : str or callable object, optional - Function that performs the actual compilation - - """ - lib_name = "mod.so" - graph_name = "mod.json" - param_name = "mod.params" - temp = utils.tempdir() - path_lib = temp.relpath(lib_name) - if not cross: - logger.debug("exporting library to %s", path_lib) - lib.export_library(path_lib) - else: - logger.debug("exporting library to %s , using cross compiler %s", path_lib, cross) - lib.export_library(path_lib, cc.cross_compiler(cross)) - - with open(temp.relpath(graph_name), "w") as graph_file: - logger.debug("writing graph to file to %s", graph_file.name) - graph_file.write(graph) - - with open(temp.relpath(param_name), "wb") as params_file: - logger.debug("writing params to file to %s", params_file.name) - params_file.write(runtime.save_param_dict(params)) + # Write dumps to file. + if dumps: + save_dumps(package_path, dumps) - logger.debug("saving module as tar file to %s", module_path) - with tarfile.open(module_path, "w") as tar: - tar.add(path_lib, lib_name) - tar.add(temp.relpath(graph_name), graph_name) - tar.add(temp.relpath(param_name), param_name) + return TVMCPackage(package_path) -def save_dumps(module_name, dumps, dump_root="."): +def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."): """ Serialize dump files to the disk. @@ -313,7 +274,6 @@ def save_dumps(module_name, dumps, dump_root="."): The output contents to be saved into the files dump_root : str, optional Path in which dump files will be created - """ for dump_format in dumps: diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 0488223c782f..89ca1b8fc329 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -25,12 +25,14 @@ import sys from abc import ABC from abc import abstractmethod +from typing import Optional, List, Dict from pathlib import Path import numpy as np from tvm import relay from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc.model import TVMCModel # pylint: disable=invalid-name @@ -284,7 +286,7 @@ def get_frontend_names(): return [frontend.name() for frontend in ALL_FRONTENDS] -def get_frontend_by_name(name): +def get_frontend_by_name(name: str): """ This function will try to get a frontend instance, based on the name provided. @@ -311,7 +313,7 @@ def get_frontend_by_name(name): ) -def guess_frontend(path): +def guess_frontend(path: str): """ This function will try to imply which framework is being used, based on the extension of the file provided in the path parameter. @@ -340,7 +342,12 @@ def guess_frontend(path): raise TVMCException("failed to infer the model format. Please specify --model-format") -def load_model(path, model_format=None, shape_dict=None, **kwargs): +def load_model( + path: str, + model_format: Optional[str] = None, + shape_dict: Optional[Dict[str, List[int]]] = None, + **kwargs, +): """Load a model from a supported framework and convert it into an equivalent relay representation. @@ -356,10 +363,8 @@ def load_model(path, model_format=None, shape_dict=None, **kwargs): Returns ------- - mod : tvm.relay.Module - The produced relay module. - params : dict - The parameters (weights) for the relay module. + tvmc_model : TVMCModel + The produced model package. 
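+
+    For example (the file name and input name are placeholders):
+
+    .. code-block:: python
+
+        tvmc_model = load_model("model.onnx", shape_dict={"input": [1, 3, 224, 224]})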
""" @@ -370,4 +375,4 @@ def load_model(path, model_format=None, shape_dict=None, **kwargs): mod, params = frontend.load(path, shape_dict, **kwargs) - return mod, params + return TVMCModel(mod, params) diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py new file mode 100644 index 000000000000..a26a47c788fe --- /dev/null +++ b/python/tvm/driver/tvmc/model.py @@ -0,0 +1,364 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This file contains the definition of a set of classes that wrap the outputs +of TVMC functions to create a simpler and more intuitive API. + +There is one class for each required stage of a TVM workflow. +The TVMCModel represents the result of importing a model into TVM, it +contains the precompiled graph definition and parameters that define +what the model does. + +Compiling a TVMCModel produces a TVMCPackage, which contains the generated +artifacts that allow the model to be run on the target hardware. + +Running a TVMCPackage produces a TVMCResult, which contains the outputs of +the model and the measured runtime. + +Examples +-------- +The following code shows a full lifecycle for a model using tvmc, first the +model is imported from an exterior framework, in this case onnx, then it +is tuned to find the best schedules on CPU, then compiled into a TVMCPackage, +and finally run. + +.. code-block:: python + tvmc_model = tvmc.load("my_model.onnx") + tuning_records = tvmc.tune(tvmc_model, target="llvm") + tvmc_package = tvmc.compile(tvmc_model, target="llvm", tuning_records=tuning_records) + result = tvmc.run(tvmc_package, device="cpu") + print(result) +""" +import os +import tarfile +from typing import Optional, Union, List, Dict, Callable, TextIO +import numpy as np + +import tvm +import tvm.contrib.cc +from tvm import relay +from tvm.contrib import utils +from tvm.relay.backend.graph_executor_factory import GraphExecutorFactoryModule + +from .common import TVMCException + + +class TVMCModel(object): + """Initialize a TVMC model from a relay model definition or a saved file. + + Parameters + ---------- + mod : tvm.IRModule, optional + The relay module corresponding to this model. + params : dict, optional + A parameter dictionary for the model. + model_path: str, optional + An alternative way to load a TVMCModel, the path to a previously + saved model. 
+ """ + + def __init__( + self, + mod: Optional[tvm.IRModule] = None, + params: Optional[Dict[str, tvm.nd.NDArray]] = None, + model_path: Optional[str] = None, + ): + if (mod is None or params is None) and (model_path is None): + raise TVMCException( + "Either mod and params must be provided " + "or a path to a previously saved TVMCModel" + ) + self._tmp_dir = utils.tempdir() + if model_path is not None: + self.load(model_path) + else: + self.mod = mod + self.params = params if params else {} + + def save(self, model_path: str): + """Save the TVMCModel to disk. + + Note that this saves the graph representation, + the parameters, and the tuning records if applicable. It will not save any + compiled artifacts. + + Parameters + ---------- + model_path : str + A full path to save this TVMCModel to including the output file name. + The file will be saved as a tar file so using a ".tar" extension is advised. + """ + temp = self._tmp_dir + + # Save relay graph + relay_name = "model.json" + relay_path = temp.relpath(relay_name) + with open(relay_path, "w") as relay_file: + relay_file.write(tvm.ir.save_json(self.mod)) + + # Save params + params_name = "model.params" + params_path = temp.relpath(params_name) + with open(params_path, "wb") as params_file: + params_file.write(relay.save_param_dict(self.params)) + + # Create a tar file. + with tarfile.open(model_path, "w") as tar: + tar.add(relay_path, relay_name) + tar.add(params_path, params_name) + # If default tuning records exist, save them as well. + if os.path.exists(self.default_tuning_records_path()): + tar.add(self.default_tuning_records_path(), "tuning_records") + # Also save the compiled package if it can be found. + if os.path.exists(self.default_package_path()): + tar.add(self.default_package_path(), "model_package.tar") + + def load(self, model_path: str): + """Load a TVMCModel from disk. + + Parameters + ---------- + model_path : str + A path to load the TVMCModel from. + """ + temp = self._tmp_dir + t = tarfile.open(model_path) + t.extractall(temp.relpath(".")) + + # Load relay IR. + relay_path = temp.relpath("model.json") + with open(relay_path, "r") as relay_file: + self.mod = tvm.ir.load_json(relay_file.read()) + + # Load parameter dictionary. + params_path = temp.relpath("model.params") + with open(params_path, "rb") as params_file: + self.params = relay.load_param_dict(params_file.read()) + + def default_tuning_records_path(self): + """Get a full path for storing tuning records in this model's temporary direcotry + + Note that when this path is used, the tuning records will be saved and loaded + when calling `save` and `load`. + + Returns + ------- + records_path: str + A path to the default location for tuning records. + """ + return self._tmp_dir.relpath("tuning_records") + + def default_package_path(self): + """Get a full path for storing a compiled package in this model's temporary direcotry + + Note that when this path is used, the package will be saved and loaded + when calling `save` and `load`. + + Returns + ------- + records_path: str + A path to the default location for tuning records. + """ + return self._tmp_dir.relpath("model_package.tar") + + def export_package( + self, + executor_factory: GraphExecutorFactoryModule, + package_path: Optional[str] = None, + cross: Optional[Union[str, Callable]] = None, + lib_format: str = "so", + ): + """Save this TVMCModel to file. 
+
+        Parameters
+        ----------
+        executor_factory : GraphExecutorFactoryModule
+            The factory containing the compiled artifacts needed to run this model.
+        package_path : str, optional
+            Where the model should be saved. Note that it will be packaged as a .tar file.
+            If not provided, the package will be saved to a generically named file in tmp.
+        cross : str or callable object, optional
+            Function that performs the actual compilation.
+        lib_format : str
+            How to export the module's function library. Must be one of "so" or "tar".
+
+        Returns
+        -------
+        package_path : str
+            The path that the package was saved to.
+        """
+        if lib_format not in ["so", "tar"]:
+            raise TVMCException("Only .so and .tar export formats are supported.")
+        lib_name = "mod." + lib_format
+        graph_name = "mod.json"
+        param_name = "mod.params"
+
+        temp = self._tmp_dir
+        if package_path is None:
+            package_path = self.default_package_path()
+        path_lib = temp.relpath(lib_name)
+
+        if not cross:
+            executor_factory.get_lib().export_library(path_lib)
+        else:
+            executor_factory.get_lib().export_library(
+                path_lib, tvm.contrib.cc.cross_compiler(cross)
+            )
+        self.lib_path = path_lib
+
+        with open(temp.relpath(graph_name), "w") as graph_file:
+            graph_file.write(executor_factory.get_json())
+
+        with open(temp.relpath(param_name), "wb") as params_file:
+            params_file.write(relay.save_param_dict(executor_factory.get_params()))
+
+        # Package up all the temp files into a tar file.
+        with tarfile.open(package_path, "w") as tar:
+            tar.add(path_lib, lib_name)
+            tar.add(temp.relpath(graph_name), graph_name)
+            tar.add(temp.relpath(param_name), param_name)
+
+        return package_path
+
+    def summary(self, file: TextIO = None):
+        """Print the IR corresponding to this model.
+
+        Arguments
+        ---------
+        file: Writable, optional
+            If specified, the summary will be written to this file.
+        """
+        print(self.mod, file=file)
+
+
+class TVMCPackage(object):
+    """Load a saved TVMCPackage from disk.
+
+    Parameters
+    ----------
+    package_path : str
+        The path to the saved TVMCPackage that will be loaded.
+    """
+
+    def __init__(self, package_path: str):
+        self._tmp_dir = utils.tempdir()
+        self.package_path = package_path
+        self.import_package(self.package_path)
+
+    def import_package(self, package_path: str):
+        """Load a TVMCPackage from a previously exported TVMCModel.
+
+        Parameters
+        ----------
+        package_path : str
+            The path to the saved TVMCPackage.
+        """
+        lib_name_so = "mod.so"
+        lib_name_tar = "mod.tar"
+        graph_name = "mod.json"
+        param_name = "mod.params"
+
+        temp = self._tmp_dir
+        t = tarfile.open(package_path)
+        t.extractall(temp.relpath("."))
+
+        with open(temp.relpath(param_name), "rb") as param_file:
+            self.params = bytearray(param_file.read())
+        self.graph = open(temp.relpath(graph_name)).read()
+        if os.path.exists(temp.relpath(lib_name_so)):
+            self.lib_name = lib_name_so
+        elif os.path.exists(temp.relpath(lib_name_tar)):
+            self.lib_name = lib_name_tar
+        else:
+            raise TVMCException("Couldn't find exported library in the package.")
+        self.lib_path = temp.relpath(self.lib_name)
+
+
+class TVMCResult(object):
+    """A class that stores the results of tvmc.run and provides helper utilities."""
+
+    def __init__(self, outputs: Dict[str, np.ndarray], times: List[float]):
+        """Create a convenience wrapper around the output of tvmc.run
+
+        Parameters
+        ----------
+        outputs : dict
+            Outputs dictionary mapping the name of the output to its numpy value.
+        times : list of float
+            The execution times measured by the time evaluator in seconds to produce outputs.
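+
+        For example, a hypothetical TVMCResult({"output_0": arr}, [0.012, 0.011])
+        would render as an execution time summary followed by the output names
+        when printed.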
+ """ + self.outputs = outputs + self.times = times + + def format_times(self): + """Format the mean, max, min and std of the execution times. + + This has the effect of producing a small table that looks like: + .. code-block:: + Execution time summary: + mean (ms) max (ms) min (ms) std (ms) + 0.14310 0.16161 0.12933 0.01004 + + Returns + ------- + str + A formatted string containing the statistics. + """ + + # timestamps + mean_ts = np.mean(self.times) * 1000 + std_ts = np.std(self.times) * 1000 + max_ts = np.max(self.times) * 1000 + min_ts = np.min(self.times) * 1000 + + header = "Execution time summary:\n{0:^10} {1:^10} {2:^10} {3:^10}".format( + "mean (ms)", "max (ms)", "min (ms)", "std (ms)" + ) + stats = "{0:^10.2f} {1:^10.2f} {2:^10.2f} {3:^10.2f}".format( + mean_ts, max_ts, min_ts, std_ts + ) + + return "%s\n%s\n" % (header, stats) + + def get_output(self, name: str): + """A helper function to grab one of the outputs by name. + + Parameters + ---------- + name : str + The name of the output to return + + Returns + ------- + output : np.ndarray + The output corresponding to name. + """ + return self.outputs[name] + + def save(self, output_path: str): + """Save the numpy outputs to disk as a .npz file. + + Parameters + ---------- + output_path : str + The path to save the numpy results to. + """ + np.savez(output_path, **self.outputs) + + def __str__(self): + stat_table = self.format_times() + output_keys = f"Output Names:\n {list(self.outputs.keys())}" + return stat_table + "\n" + output_keys diff --git a/python/tvm/driver/tvmc/result_utils.py b/python/tvm/driver/tvmc/result_utils.py new file mode 100644 index 000000000000..10d3159c8969 --- /dev/null +++ b/python/tvm/driver/tvmc/result_utils.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This file contains utility functions for processing the outputs +of TVMC models. These utilities are likely to be task specific, +overtime more will be added to support more machine learning tasks. + +Examples +-------- +The following code shows how one might postprocess +the output of a classification model. + +.. code-block:: python + result = tvmc.run(tvmc_package, device="cpu") + top_results = result_utils.get_top_results(max_results=5) +""" +import numpy as np +from .model import TVMCResult + + +def get_top_results(result: TVMCResult, max_results: int): + """Return the top n results from the output tensor. + + This function is primarily for image classification and will + not necessarily generalize. + + Parameters + ---------- + result : TVMCResult + The output of a TVMCModel + max_results : int + Number of results to return + + Returns + ------- + top_results : np.array + Results array of shape (2, n). + The first row is the indices and the second is the values. 
+ + """ + output = np.copy(result.outputs["output_0"]) + sorted_labels = output.argsort()[0][-max_results:][::-1] + output.sort() + sorted_values = output[0][-max_results:][::-1] + top_results = np.array([sorted_labels, sorted_values]) + return top_results diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index e1a21b2c2842..b15a16a9fd38 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -19,20 +19,21 @@ """ import json import logging -import os -import tarfile -import tempfile +from typing import Optional, Dict, List, Union import numpy as np +import tvm from tvm import rpc from tvm.autotvm.measure import request_remote from tvm.contrib import graph_executor as runtime from tvm.contrib.debugger import debug_executor -from tvm.relay import load_param_dict +from tvm.relay.param_dict import load_param_dict from . import common +from .model import TVMCPackage, TVMCResult from .common import TVMCException from .main import register_parser +from .result_utils import get_top_results # pylint: disable=invalid-name @@ -82,7 +83,10 @@ def add_run_parser(subparsers): "making it take longer to be generated.", ) parser.add_argument( - "--repeat", metavar="N", type=int, default=1, help="repeat the run n times. Defaults to '1'" + "--repeat", metavar="N", type=int, default=1, help="run the model n times. Defaults to '1'" + ) + parser.add_argument( + "--number", metavar="N", type=int, default=1, help="repeat the run n times. Defaults to '1'" ) parser.add_argument( "--rpc-key", @@ -112,8 +116,10 @@ def drive_run(args): except IOError as ex: raise TVMCException("Error loading inputs file: %s" % ex) - outputs, times = run_module( - args.FILE, + tvmc_package = TVMCPackage(package_path=args.FILE) + + result = run_module( + tvmc_package, args.device, hostname=rpc_hostname, port=rpc_port, @@ -121,25 +127,26 @@ def drive_run(args): inputs=inputs, fill_mode=args.fill_mode, repeat=args.repeat, + number=args.number, profile=args.profile, ) if args.print_time: - stat_table = format_times(times) + stat_table = result.format_times() # print here is intentional print(stat_table) if args.print_top: - top_results = get_top_results(outputs, args.print_top) + top_results = get_top_results(result, args.print_top) # print here is intentional print(top_results) if args.outputs: # Save the outputs - np.savez(args.outputs, **outputs) + result.save(args.outputs) -def get_input_info(graph_str, params): +def get_input_info(graph_str: str, params: Dict[str, tvm.nd.NDArray]): """Return the 'shape' and 'dtype' dictionaries for the input tensors of a compiled module. @@ -155,8 +162,8 @@ def get_input_info(graph_str, params): ---------- graph_str : str JSON graph of the module serialized as a string. - params : bytearray - Params serialized as a bytearray. + params : dict + Parameter dictionary mapping name to value. 
Returns ------- @@ -179,14 +186,14 @@ def get_input_info(graph_str, params): shape_dict[name] = graph["attrs"]["shape"][1][node_id] dtype_dict[name] = graph["attrs"]["dltype"][1][node_id] - logger.debug("collecting graph input shape and type:") - logger.debug("graph input shape: %s", shape_dict) - logger.debug("graph input type: %s", dtype_dict) + logger.debug("Collecting graph input shape and type:") + logger.debug("Graph input shape: %s", shape_dict) + logger.debug("Graph input type: %s", dtype_dict) return shape_dict, dtype_dict -def generate_tensor_data(shape, dtype, fill_mode): +def generate_tensor_data(shape: tuple, dtype: str, fill_mode: str): """Generate data to produce a tensor of given shape and dtype. Random data generation depends on the dtype. For int8 types, @@ -226,7 +233,12 @@ def generate_tensor_data(shape, dtype, fill_mode): return tensor -def make_inputs_dict(shape_dict, dtype_dict, inputs=None, fill_mode="random"): +def make_inputs_dict( + shape_dict: Dict[str, List[int]], + dtype_dict: Dict[str, str], + inputs: Optional[Dict[str, np.ndarray]] = None, + fill_mode: str = "random", +): """Make the inputs dictionary for a graph. Use data from 'inputs' where specified. For input tensors @@ -289,15 +301,16 @@ def make_inputs_dict(shape_dict, dtype_dict, inputs=None, fill_mode="random"): def run_module( - module_file, - device, - hostname=None, - port=9090, - rpc_key=None, - inputs=None, - fill_mode="random", - repeat=1, - profile=False, + tvmc_package: TVMCPackage, + device: str, + hostname: Optional[str] = None, + port: Union[int, str] = 9090, + rpc_key: Optional[str] = None, + inputs: Optional[Dict[str, np.ndarray]] = None, + fill_mode: str = "random", + repeat: int = 10, + number: int = 10, + profile: bool = False, ): """Run a compiled graph executor module locally or remotely with optional input values. @@ -307,8 +320,8 @@ def run_module( Parameters ---------- - module_file : str - The path to the module file (a .tar file). + tvmc_package: TVMCPackage + The compiled model package object that will be run. device: str, the device (e.g. "cpu" or "gpu") to be targeted by the RPC session, local or remote). @@ -320,13 +333,16 @@ def run_module( The tracker key of the target device. If this is set, it will be assumed that remote points to a tracker. inputs : dict, optional - A dictionary that maps input names to numpy values. + A dictionary that maps input names to numpy values. If not provided, + inputs will be generated using the fill_mode argument. fill_mode : str, optional The fill-mode to use when generating data for input tensors. Valid options are "zeros", "ones" and "random". Defaults to "random". repeat : int, optional How many times to repeat the run. + number : int, optional + The number of runs to measure within each repeat. profile : bool Whether to profile the run with the debug runtime. 
@@ -337,135 +353,73 @@ def run_module( times : list of str execution times generated by the time evaluator """ - - with tempfile.TemporaryDirectory() as tmp_dir: - logger.debug("extracting module file %s", module_file) - t = tarfile.open(module_file) - t.extractall(tmp_dir) - graph = open(os.path.join(tmp_dir, "mod.json")).read() - params = bytearray(open(os.path.join(tmp_dir, "mod.params"), "rb").read()) - - if hostname: - # Remote RPC - if rpc_key: - logger.debug("running on remote RPC tracker with key %s", rpc_key) - session = request_remote(rpc_key, hostname, port, timeout=1000) - else: - logger.debug("running on remote RPC with no key") - session = rpc.connect(hostname, port) - else: - # Local - logger.debug("running a local session") - session = rpc.LocalSession() - - session.upload(os.path.join(tmp_dir, "mod.so")) - lib = session.load_module("mod.so") - - # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) - logger.debug("device is %s", device) - if device == "gpu": - dev = session.gpu() - elif device == "cl": - dev = session.cl() - else: - assert device == "cpu" - dev = session.cpu() - - if profile: - logger.debug("creating runtime with profiling enabled") - module = debug_executor.create(graph, lib, dev, dump_root="./prof") + if not isinstance(tvmc_package, TVMCPackage): + raise TVMCException( + "This model doesn't seem to have been compiled yet. " + "Try calling tvmc.compile on the model before running it." + ) + + if hostname: + if isinstance(port, str): + port = int(port) + # Remote RPC + if rpc_key: + logger.debug("Running on remote RPC tracker with key %s.", rpc_key) + session = request_remote(rpc_key, hostname, port, timeout=1000) else: - logger.debug("creating runtime with profiling disabled") - module = runtime.create(graph, lib, dev) - - logger.debug("load params into the runtime module") - module.load_params(params) - - shape_dict, dtype_dict = get_input_info(graph, params) - inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) - - logger.debug("setting inputs to the module") - module.set_input(**inputs_dict) - - # Run must be called explicitly if profiling - if profile: - logger.debug("running the module with profiling enabled") - module.run() - - # create the module time evaluator (returns a function) - timer = module.module.time_evaluator("run", dev, 1, repeat=repeat) - # call the evaluator function to invoke the module and save execution times - prof_result = timer() - # collect a list of execution times from the profiling results - times = prof_result.results - - logger.debug("collecting the output tensors") - num_outputs = module.get_num_outputs() - outputs = {} - for i in range(num_outputs): - output_name = "output_{}".format(i) - outputs[output_name] = module.get_output(i).asnumpy() - - return outputs, times - - -def get_top_results(outputs, max_results): - """Return the top n results from the output tensor. - - This function is primarily for image classification and will - not necessarily generalise. - - Parameters - ---------- - outputs : dict - Outputs dictionary - {output_name: np.array}. - max_results : int - Number of results to return - - Returns - ------- - top_results : np.array - Results array of shape (2, n). - The first row is the indices and the second is the values. 
- - """ - output = np.copy(outputs["output_0"]) - sorted_labels = output.argsort()[0][-max_results:][::-1] - output.sort() - sorted_values = output[0][-max_results:][::-1] - top_results = np.array([sorted_labels, sorted_values]) - return top_results - - -def format_times(times): - """Format the mean, max, min and std of the execution times. - - This has the effect of producing a small table that looks like: - - Execution time summary: - mean (ms) max (ms) min (ms) std (ms) - 0.14310 0.16161 0.12933 0.01004 - - Parameters - ---------- - times : list - A list of execution times (in seconds). - - Returns - ------- - str - A formatted string containing the statistics. - """ - - # timestamps - mean_ts = np.mean(times) * 1000 - std_ts = np.std(times) * 1000 - max_ts = np.max(times) * 1000 - min_ts = np.min(times) * 1000 - - header = "Execution time summary:\n{0:^10} {1:^10} {2:^10} {3:^10}".format( - "mean (ms)", "max (ms)", "min (ms)", "std (ms)" - ) - stats = "{0:^10.2f} {1:^10.2f} {2:^10.2f} {3:^10.2f}".format(mean_ts, max_ts, min_ts, std_ts) + logger.debug("Running on remote RPC with no key.") + session = rpc.connect(hostname, port) + else: + # Local + logger.debug("Running a local session.") + session = rpc.LocalSession() + + session.upload(tvmc_package.lib_path) + lib = session.load_module(tvmc_package.lib_name) + + # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) + logger.debug("Device is %s.", device) + if device == "gpu": + dev = session.gpu() + elif device == "cl": + dev = session.cl() + else: + assert device == "cpu" + dev = session.cpu() - return "%s\n%s\n" % (header, stats) + if profile: + logger.debug("Creating runtime with profiling enabled.") + module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof") + else: + logger.debug("Creating runtime with profiling disabled.") + module = runtime.create(tvmc_package.graph, lib, dev) + + logger.debug("Loading params into the runtime module.") + module.load_params(tvmc_package.params) + + shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params) + inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) + + logger.debug("Setting inputs to the module.") + module.set_input(**inputs_dict) + + # Run must be called explicitly if profiling + if profile: + logger.info("Running the module with profiling enabled.") + module.run() + + # create the module time evaluator (returns a function) + timer = module.module.time_evaluator("run", dev, number=number, repeat=repeat) + # call the evaluator function to invoke the module and save execution times + prof_result = timer() + # collect a list of execution times from the profiling results + times = prof_result.results + + logger.debug("Collecting the output tensors.") + num_outputs = module.get_num_outputs() + outputs = {} + for i in range(num_outputs): + output_name = "output_{}".format(i) + outputs[output_name] = module.get_output(i).asnumpy() + + return TVMCResult(outputs, times) diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index ffdfbccbdd34..80fb71d84388 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -167,6 +167,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams") max_threads_per_block, max_vthread_extent, warp_size); }); +TVM_REGISTER_GLOBAL("auto_scheduler.GetDefaultHardwareParams") + .set_body_typed([](Target target, Target target_host) { + return HardwareParamsNode::GetDefaultHardwareParams(target, target_host); + }); 
+
 TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
 .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target,
 Target target_host, Optional<HardwareParams> hardware_params,
diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py
index 3345b4f07585..f7cbf92bca30 100644
--- a/tests/python/driver/tvmc/conftest.py
+++ b/tests/python/driver/tvmc/conftest.py
@@ -41,7 +41,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir):
 return os.path.join(temp_dir, model_sub_path)


-def get_sample_compiled_module(target_dir):
+def get_sample_compiled_module(target_dir, package_filename):
 """Support function that returns a TFLite compiled module"""
 base_url = "https://storage.googleapis.com/download.tensorflow.org/models"
 model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
@@ -51,8 +51,10 @@ def get_sample_compiled_module(target_dir):
 temp_dir=target_dir,
 )

- mod, params = tvmc.frontends.load_model(model_file)
- return tvmc.compiler.compile_model(mod, params, target="llvm")
+ tvmc_model = tvmc.frontends.load_model(model_file)
+ return tvmc.compiler.compile_model(
+ tvmc_model, target="llvm", package_path=os.path.join(target_dir, package_filename)
+ )


 # PyTest fixtures
@@ -100,6 +102,29 @@ def keras_resnet50(tmpdir_factory):
 return model_file_name


+@pytest.fixture(scope="session")
+def keras_simple(tmpdir_factory):
+ try:
+ from tensorflow import keras
+ except ImportError:
+ # not all environments provide TensorFlow, so skip this fixture
+ # if that is the case.
+ return ""
+
+ model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "simple_conv.h5")
+ model = keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=[32, 32, 3], batch_size=1),
+ keras.layers.Conv2D(8, kernel_size=(3, 3)),
+ keras.layers.Flatten(),
+ keras.layers.Dense(64),
+ ]
+ )
+ model.save(model_file_name)
+
+ return model_file_name
+
+
 @pytest.fixture(scope="session")
 def pytorch_resnet18(tmpdir_factory):
 try:
@@ -129,7 +154,18 @@ def onnx_resnet50():


 @pytest.fixture(scope="session")
-def tflite_compiled_module_as_tarfile(tmpdir_factory):
+def onnx_mnist():
+ base_url = "https://github.com/onnx/models/raw/master/vision/classification/mnist/model"
+ file_to_download = "mnist-1.onnx"
+ model_file = download_testdata(
+ "{}/{}".format(base_url, file_to_download), file_to_download, module=["tvmc"]
+ )
+
+ return model_file
+
+
+@pytest.fixture(scope="session")
+def tflite_compiled_model(tmpdir_factory):

 # Not all CI environments will have TFLite installed
 # so we need to safely skip this fixture that will
@@ -143,12 +179,7 @@
 return ""

 target_dir = tmpdir_factory.mktemp("data")
- graph, lib, params, _ = get_sample_compiled_module(target_dir)
-
- module_file = os.path.join(target_dir, "mock.tar")
- tvmc.compiler.save_module(module_file, graph, lib, params)
-
- return module_file
+ return get_sample_compiled_module(target_dir, "mock.tar")


 @pytest.fixture(scope="session")
diff --git a/tests/python/driver/tvmc/test_autoscheduler.py b/tests/python/driver/tvmc/test_autoscheduler.py
index 25525eb9ce97..f1d750fa4078 100644
--- a/tests/python/driver/tvmc/test_autoscheduler.py
+++ b/tests/python/driver/tvmc/test_autoscheduler.py
@@ -14,10 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
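Note: the reworked autoscheduler tests below drive tuning end to end through the top-level `tvmc.tune` entry point instead of calling `tvmc.autotuner.schedule_tasks` directly. A hedged sketch of the new call shape; the model path and record file name are illustrative:

```python
from tvm import auto_scheduler
from tvm.driver import tvmc

tvmc_model = tvmc.frontends.load_model("simple_conv.h5")  # illustrative path
hardware_params = auto_scheduler.HardwareParams(num_cores=4, target="llvm")
tvmc.tune(
    tvmc_model,
    target="llvm",
    tuning_records="autoscheduler.json",  # where new records are written
    enable_autoscheduler=True,
    trials=2,
    hardware_params=hardware_params,
)
```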
-import json import pytest import os -import tarfile from os import path @@ -26,28 +24,30 @@ def _get_tasks(model): - mod, params = tvmc.frontends.load_model(model) - tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(mod, params, "llvm") + tvmc_model = tvmc.frontends.load_model(model) + tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks( + tvmc_model.mod, tvmc_model.params, "llvm" + ) return (tasks, weights) -def _autoscheduler_test_helper( - model, tmpdir_name, tasks_weights=None, early_stopping=1, tuning_records=None -): - tasks, weights = tasks_weights if tasks_weights else _get_tasks(model) +def _autoscheduler_test_helper(model, tmpdir_name, early_stopping=1, prior_records=None): + tvmc_model = tvmc.frontends.load_model(model) log_file = os.path.join(tmpdir_name, "autoscheduler.json") - tuning_options = auto_scheduler.TuningOptions( - num_measure_trials=1, - measure_callbacks=[auto_scheduler.RecordToFile(log_file)], - runner="local", - builder="local", - verbose=0, + hardware_params = auto_scheduler.HardwareParams(num_cores=4, target="llvm") + + tvmc.tune( + tvmc_model, + target="llvm", + tuning_records=log_file, + prior_records=prior_records, early_stopping=early_stopping, + enable_autoscheduler=True, + trials=2, + hardware_params=hardware_params, ) - tvmc.autotuner.schedule_tasks(tasks[:1], weights[:1], tuning_options, tuning_records) - # testing whether the log file was produced assert path.exists(log_file), "autoscheduler log file should exist" @@ -59,10 +59,10 @@ def _autoscheduler_test_helper( return log_file -def test_get_tuning_tasks(onnx_resnet50): - pytest.importorskip("onnx") +def test_get_tuning_tasks(keras_simple): + pytest.importorskip("tensorflow") - tasks, weights = _get_tasks(onnx_resnet50) + tasks, weights = _get_tasks(keras_simple) expected_task_type = auto_scheduler.SearchTask assert type(tasks) is list @@ -70,32 +70,25 @@ def test_get_tuning_tasks(onnx_resnet50): assert all([type(x) is expected_task_type for x in tasks]) is True -def test_tune_tasks(onnx_resnet50, tmpdir_factory): - pytest.importorskip("onnx") +def test_tune_tasks(keras_simple, tmpdir_factory): + pytest.importorskip("tensorflow") tmpdir_name = tmpdir_factory.mktemp("data") - _autoscheduler_test_helper(onnx_resnet50, tmpdir_name) + _autoscheduler_test_helper(keras_simple, tmpdir_name) -def test_tune_tasks__tuning_records(onnx_resnet50, tmpdir_factory): - pytest.importorskip("onnx") +def test_tune_tasks__tuning_records(keras_simple, tmpdir_factory): + pytest.importorskip("tensorflow") tmpdir_name = tmpdir_factory.mktemp("data") - output_log_phase_1 = _autoscheduler_test_helper(onnx_resnet50, tmpdir_name) + output_log_phase_1 = _autoscheduler_test_helper(keras_simple, tmpdir_name) # Exercises transfer learning by making sure a previous log exists - _autoscheduler_test_helper(onnx_resnet50, tmpdir_name, tuning_records=output_log_phase_1) - - -def test_tune_tasks__no_early_stopping(onnx_resnet50, tmpdir_factory): - pytest.importorskip("onnx") - - tmpdir_name = tmpdir_factory.mktemp("data") - _autoscheduler_test_helper(onnx_resnet50, tmpdir_name, tasks_weights=None, early_stopping=None) + _autoscheduler_test_helper(keras_simple, tmpdir_name, prior_records=output_log_phase_1) -def test_tune_tasks__no_tuning_records(onnx_resnet50, tmpdir_factory): - pytest.importorskip("onnx") +def test_tune_tasks__no_early_stopping(keras_simple, tmpdir_factory): + pytest.importorskip("tensorflow") tmpdir_name = tmpdir_factory.mktemp("data") - _autoscheduler_test_helper(onnx_resnet50, tmpdir_name, 
tasks_weights=None, tuning_records=None) + _autoscheduler_test_helper(keras_simple, tmpdir_name, early_stopping=None) diff --git a/tests/python/driver/tvmc/test_autotuner.py b/tests/python/driver/tvmc/test_autotuner.py index 5ce4ca95c810..e82e33b158b1 100644 --- a/tests/python/driver/tvmc/test_autotuner.py +++ b/tests/python/driver/tvmc/test_autotuner.py @@ -14,10 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import json import pytest import os -import tarfile from os import path @@ -26,8 +24,8 @@ def _get_tasks(model): - mod, params = tvmc.frontends.load_model(model) - return tvmc.autotuner.autotvm_get_tuning_tasks(mod, params, "llvm") + tvmc_model = tvmc.frontends.load_model(model) + return tvmc.autotuner.autotvm_get_tuning_tasks(tvmc_model.mod, tvmc_model.params, "llvm") def _get_measure_options(): @@ -36,20 +34,18 @@ def _get_measure_options(): ) -def _tuner_test_helper( - model, tuner_name, tmpdir_name, tasks=None, early_stopping=1, tuning_records=None -): - tasks = tasks if tasks else _get_tasks(model) +def _tuner_test_helper(model, tuner_name, tmpdir_name, early_stopping=1, prior_records=None): + tvmc_model = tvmc.frontends.load_model(model) log_file = os.path.join(tmpdir_name, "log_{}.txt".format(tuner_name)) - tvmc.autotuner.tune_tasks( - tasks=[tasks[0]], - log_file=log_file, - measure_option=_get_measure_options(), + tvmc.tune( + tvmc_model, + target="llvm", + tuning_records=log_file, + prior_records=prior_records, tuner=tuner_name, - trials=1, + trials=4, early_stopping=early_stopping, - tuning_records=tuning_records, ) # testing whether the log file was produced @@ -63,10 +59,10 @@ def _tuner_test_helper( return log_file -def test_get_tuning_tasks(onnx_resnet50): +def test_get_tuning_tasks(onnx_mnist): pytest.importorskip("onnx") - sut = _get_tasks(onnx_resnet50) + sut = _get_tasks(onnx_mnist) expected_task_type = autotvm.task.Task assert type(sut) is list @@ -74,76 +70,85 @@ def test_get_tuning_tasks(onnx_resnet50): assert all([type(x) is expected_task_type for x in sut]) is True -def test_tune_tasks__tuner__xgb(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__tuner__xgb(onnx_mnist, tmpdir_factory): pytest.importorskip("onnx") tmpdir_name = tmpdir_factory.mktemp("data") - _tuner_test_helper(onnx_resnet50, "xgb", tmpdir_name) + _tuner_test_helper(onnx_mnist, "xgb", tmpdir_name) -def test_tune_tasks__tuner__xgb_knob(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__tuner__xgb_knob(onnx_mnist, tmpdir_factory): pytest.importorskip("onnx") tmpdir_name = tmpdir_factory.mktemp("data") - _tuner_test_helper(onnx_resnet50, "xgb_knob", tmpdir_name) + _tuner_test_helper(onnx_mnist, "xgb_knob", tmpdir_name) -def test_tune_tasks__tuner__ga(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__tuner__ga(onnx_mnist, tmpdir_factory): pytest.importorskip("onnx") tmpdir_name = tmpdir_factory.mktemp("data") - _tuner_test_helper(onnx_resnet50, "ga", tmpdir_name) + _tuner_test_helper(onnx_mnist, "ga", tmpdir_name) -def test_tune_tasks__tuner__random(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__tuner__random(onnx_mnist, tmpdir_factory): pytest.importorskip("onnx") tmpdir_name = tmpdir_factory.mktemp("data") - _tuner_test_helper(onnx_resnet50, "random", tmpdir_name) + _tuner_test_helper(onnx_mnist, "random", tmpdir_name) -def test_tune_tasks__tuner__gridsearch(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__tuner__gridsearch(onnx_mnist, tmpdir_factory): 
pytest.importorskip("onnx") tmpdir_name = tmpdir_factory.mktemp("data") - _tuner_test_helper(onnx_resnet50, "gridsearch", tmpdir_name) + _tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name) -def test_tune_tasks__tuner__gridsearch__tuning_records(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__tuner__gridsearch__tuning_records(onnx_mnist, tmpdir_factory): pytest.importorskip("onnx") tmpdir_name = tmpdir_factory.mktemp("data") - output_log_phase_1 = _tuner_test_helper(onnx_resnet50, "gridsearch", tmpdir_name) + output_log_phase_1 = _tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name) # Exercises transfer learning by making sure a previous log exists - _tuner_test_helper(onnx_resnet50, "gridsearch", tmpdir_name, tuning_records=output_log_phase_1) + _tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name, prior_records=output_log_phase_1) -def test_tune_tasks__tuner__ga__empty_tasks(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__tuner__ga__empty_tasks(tmpdir_factory): pytest.importorskip("onnx") tmpdir_name = tmpdir_factory.mktemp("data") - _tuner_test_helper(onnx_resnet50, "ga", tmpdir_name, tasks=[]) + log_file = os.path.join(tmpdir_name, "log_{}.txt".format("ga")) + + tvmc.autotuner.tune_tasks( + tasks=[], + log_file=log_file, + measure_option=_get_measure_options(), + tuner="ga", + trials=1, + early_stopping=1, + ) -def test_tune_tasks__tuner__xgb__no_early_stopping(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__tuner__xgb__no_early_stopping(onnx_mnist, tmpdir_factory): pytest.importorskip("onnx") tmpdir_name = tmpdir_factory.mktemp("data") - _tuner_test_helper(onnx_resnet50, "xgb", tmpdir_name, tasks=None, early_stopping=None) + _tuner_test_helper(onnx_mnist, "xgb", tmpdir_name, early_stopping=None) -def test_tune_tasks__tuner__xgb__no_tuning_records(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__tuner__xgb__no_tuning_records(onnx_mnist, tmpdir_factory): pytest.importorskip("onnx") tmpdir_name = tmpdir_factory.mktemp("data") - _tuner_test_helper(onnx_resnet50, "xgb", tmpdir_name, tasks=None, tuning_records=None) + _tuner_test_helper(onnx_mnist, "xgb", tmpdir_name, prior_records=None) -def test_tune_tasks__invalid_tuner(onnx_resnet50, tmpdir_factory): +def test_tune_tasks__invalid_tuner(onnx_mnist, tmpdir_factory): pytest.importorskip("onnx") - tasks = _get_tasks(onnx_resnet50) + tasks = _get_tasks(onnx_mnist) log_file = os.path.join(tmpdir_factory.mktemp("data"), "log2.txt") with pytest.raises(tvmc.common.TVMCException): diff --git a/tests/python/driver/tvmc/test_command_line.py b/tests/python/driver/tvmc/test_command_line.py new file mode 100644 index 000000000000..66a32160522b --- /dev/null +++ b/tests/python/driver/tvmc/test_command_line.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
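Note: as the reworked autotuner tests above show, AutoTVM tuning now also goes through `tvmc.tune`, with `prior_records` carrying a previous log for transfer learning. A minimal sketch under those assumptions; the model and log paths are illustrative:

```python
from tvm.driver import tvmc

tvmc_model = tvmc.frontends.load_model("mnist-1.onnx")  # illustrative path
tvmc.tune(
    tvmc_model,
    target="llvm",
    tuning_records="log_xgb.txt",  # output records
    prior_records=None,            # or a previous log to seed transfer learning
    tuner="xgb",
    trials=4,
    early_stopping=1,
)
```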
+import pytest +import os + +from tvm.driver.tvmc.main import _main + + +def test_tvmc_cl_workflow(keras_simple, tmpdir_factory): + pytest.importorskip("tensorflow") + + tmpdir = tmpdir_factory.mktemp("data") + + # Test model tuning + log_path = os.path.join(tmpdir, "keras-autotuner_records.json") + tuning_str = ( + f"tvmc tune --target llvm --output {log_path} " + f"--trials 2 --enable-autoscheduler {keras_simple}" + ) + tuning_args = tuning_str.split(" ")[1:] + _main(tuning_args) + assert os.path.exists(log_path) + + # Test model compilation + package_path = os.path.join(tmpdir, "keras-tvm.tar") + compile_str = ( + f"tvmc compile --target llvm --tuning-records {log_path} " + f"--output {package_path} {keras_simple}" + ) + compile_args = compile_str.split(" ")[1:] + _main(compile_args) + assert os.path.exists(package_path) + + # Test running the model + output_path = os.path.join(tmpdir, "predictions.npz") + run_str = f"tvmc run --outputs {output_path} {package_path}" + run_args = run_str.split(" ")[1:] + _main(run_args) + assert os.path.exists(output_path) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 24fa452d05c1..a023689cc86d 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import argparse import os import shutil from os import path @@ -28,6 +27,7 @@ from tvm.contrib.target.vitis_ai import vitis_ai_available from tvm.driver import tvmc +from tvm.driver.tvmc.model import TVMCPackage def test_save_dumps(tmpdir_factory): @@ -45,16 +45,16 @@ def test_save_dumps(tmpdir_factory): def verify_compile_tflite_module(model, shape_dict=None): pytest.importorskip("tflite") - mod, params = tvmc.load(model, shape_dict=shape_dict) - graph, lib, params, dumps = tvmc.compile( - mod, params, target="llvm", dump_code="ll", alter_layout="NCHW" - ) + tvmc_model = tvmc.load(model, shape_dict=shape_dict) + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW") + dumps_path = tvmc_package.package_path + ".ll" # check for output types - assert type(graph) is str - assert type(lib) is tvm.runtime.module.Module - assert type(params) is dict - assert type(dumps) is dict + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): @@ -75,35 +75,42 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") - mod, params = tvmc.load(tflite_mobilenet_v1_1_quant) - graph, lib, params, dumps = tvmc.compile( - mod, - params, + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + tvmc_package = tvmc.compile( + tvmc_model, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'", dump_code="asm", + cross="aarch64-linux-gnu-gcc", ) + dumps_path = tvmc_package.package_path + ".asm" # check for output types - assert type(graph) is str - assert type(lib) is tvm.runtime.module.Module - assert type(params) is dict - assert type(dumps) is dict + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert 
type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) def test_compile_keras__save_module(keras_resnet50, tmpdir_factory): # some CI environments wont offer tensorflow/Keras, so skip in case it is not present pytest.importorskip("tensorflow") - mod, params = tvmc.load(keras_resnet50) - graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm", dump_code="ll") - expected_temp_dir = tmpdir_factory.mktemp("saved_output") expected_file_name = "saved.tar" module_file = os.path.join(expected_temp_dir, expected_file_name) - tvmc.compiler.save_module(module_file, graph, lib, params) + + tvmc_model = tvmc.load(keras_resnet50) + tvmc.compile(tvmc_model, target="llvm", dump_code="ll", package_path=module_file) assert os.path.exists(module_file), "output file {0} should exist".format(module_file) + # Test that we can load back in a module. + tvmc_package = TVMCPackage(package_path=module_file) + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.graph) is str + assert type(tvmc_package.params) is bytearray + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @pytest.mark.skipif( @@ -113,34 +120,36 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50): # some CI environments wont offer tensorflow/Keras, so skip in case it is not present pytest.importorskip("tensorflow") - mod, params = tvmc.load(keras_resnet50) - graph, lib, params, dumps = tvmc.compile( - mod, - params, + tvmc_model = tvmc.load(keras_resnet50) + tvmc_package = tvmc.compile( + tvmc_model, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'", dump_code="asm", + cross="aarch64-linux-gnu-gcc", ) + dumps_path = tvmc_package.package_path + ".asm" # check for output types - assert type(graph) is str - assert type(lib) is tvm.runtime.module.Module - assert type(params) is dict - assert type(dumps) is dict - assert "asm" in dumps.keys() + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) def verify_compile_onnx_module(model, shape_dict=None): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") - mod, params = tvmc.load(model, shape_dict=shape_dict) - graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm", dump_code="ll") + tvmc_model = tvmc.load(model, shape_dict=shape_dict) + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll") + dumps_path = tvmc_package.package_path + ".ll" # check for output types - assert type(graph) is str - assert type(lib) is tvm.runtime.module.Module - assert type(params) is dict - assert type(dumps) is dict - assert "ll" in dumps.keys() + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) def test_compile_onnx_module(onnx_resnet50): @@ -160,38 +169,40 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") - mod, params = tvmc.load(onnx_resnet50) - graph, lib, params, dumps = tvmc.compile( - mod, - params, + tvmc_model = tvmc.load(onnx_resnet50) + tvmc_package = tvmc.compile( + tvmc_model, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", dump_code="asm", + 
cross="aarch64-linux-gnu-gcc", ) + dumps_path = tvmc_package.package_path + ".asm" # check for output types - assert type(graph) is str - assert type(lib) is tvm.runtime.module.Module - assert type(params) is dict - assert type(dumps) is dict - assert "asm" in dumps.keys() + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) @tvm.testing.requires_opencl def test_compile_opencl(tflite_mobilenet_v1_0_25_128): pytest.importorskip("tflite") - mod, params = tvmc.load(tflite_mobilenet_v1_0_25_128) - graph, lib, params, dumps = tvmc.compile( - mod, - params, + tvmc_model = tvmc.load(tflite_mobilenet_v1_0_25_128) + tvmc_package = tvmc.compile( + tvmc_model, target="opencl --host=llvm", - alter_layout="NCHW", + desired_layout="NCHW", ) + dumps_path = tvmc_package.package_path + ".asm" # check for output types - assert type(graph) is str - assert type(lib) is tvm.runtime.module.Module - assert type(params) is dict - assert type(dumps) is dict + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) @pytest.mark.skipif( @@ -200,16 +211,16 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128): ) def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") - mod, params = tvmc.load(tflite_mobilenet_v1_1_quant) - graph, lib, params, dumps = tvmc.compile( - mod, params, target="ethos-n77, llvm", dump_code="relay" - ) + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + tvmc_package = tvmc.compile(tvmc_model, target="ethos-n77, llvm", dump_code="relay") + dumps_path = tvmc_package.package_path + ".relay" # check for output types - assert type(graph) is str - assert type(lib) is tvm.runtime.module.Module - assert type(params) is dict - assert type(dumps) is dict + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) @pytest.mark.skipif( @@ -219,36 +230,38 @@ def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") - mod, params = tvmc.load(tflite_mobilenet_v1_1_quant) - graph, lib, params, dumps = tvmc.compiler.compile_model( - mod, - params, + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + tvmc_package = tvmc.compiler.compile_model( + tvmc_model, target="vitis-ai -dpu=DPUCZDX8G-zcu104 -export_runtime_module=vitis_ai.rtmod, llvm", dump_code="relay", ) + dumps_path = tvmc_package.package_path + ".relay" # check for output types - assert type(graph) is str - assert type(lib) is tvm.runtime.module.Module - assert type(params) is dict - assert type(dumps) is dict + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) @mock.patch("tvm.relay.build") @mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target") @mock.patch("tvm.driver.tvmc.load") @mock.patch("tvm.transform.PassContext") -def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_relay): 
+@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None) +def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock_ct, mock_relay): mock_codegen = {} mock_codegen["config_key"] = "relay.ext.mock.options" mock_codegen["pass_pipeline"] = lambda *args, **kwargs: None - mock_fe.return_value = (None, None) + mock_fe.return_value = mock.MagicMock() mock_ct.return_value = mock_codegen mock_relay.return_value = mock.MagicMock() - mod, params = tvmc.load("no_file_needed") - graph, lib, params, dumps = tvmc.compile(mod, params, target="mockcodegen -testopt=value, llvm") + tvmc_model = tvmc.load("no_file_needed") + tvmc.compile(tvmc_model, target="mockcodegen -testopt=value, llvm") mock_pc.assert_called_once_with( opt_level=3, diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index 3da63d43ef29..adf62eb5c7e6 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -23,6 +23,7 @@ from tvm.driver import tvmc from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc.model import TVMCModel def test_get_frontends_contains_only_strings(): @@ -108,11 +109,12 @@ def test_load_model__tflite(tflite_mobilenet_v1_1_quant): # some CI environments wont offer TFLite, so skip in case it is not present pytest.importorskip("tflite") - mod, params = tvmc.load(tflite_mobilenet_v1_1_quant) - assert type(mod) is IRModule - assert type(params) is dict + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + assert type(tvmc_model) is TVMCModel + assert type(tvmc_model.mod) is IRModule + assert type(tvmc_model.params) is dict # check whether one known value is part of the params dict - assert "_param_1" in params.keys() + assert "_param_1" in tvmc_model.params.keys() @pytest.mark.parametrize("load_model_kwargs", [{}, {"layout": "NCHW"}]) @@ -120,40 +122,43 @@ def test_load_model__keras(keras_resnet50, load_model_kwargs): # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present pytest.importorskip("tensorflow") - mod, params = tvmc.frontends.load_model(keras_resnet50, **load_model_kwargs) - assert type(mod) is IRModule - assert type(params) is dict + tvmc_model = tvmc.frontends.load_model(keras_resnet50, **load_model_kwargs) + assert type(tvmc_model) is TVMCModel + assert type(tvmc_model.mod) is IRModule + assert type(tvmc_model.params) is dict ## check whether one known value is part of the params dict - assert "_param_1" in params.keys() + assert "_param_1" in tvmc_model.params.keys() def verify_load_model__onnx(model, **kwargs): - mod, params = tvmc.frontends.load_model(model, **kwargs) - assert type(mod) is IRModule - assert type(params) is dict - return mod, params + tvmc_model = tvmc.frontends.load_model(model, **kwargs) + assert type(tvmc_model) is TVMCModel + assert type(tvmc_model.mod) is IRModule + assert type(tvmc_model.params) is dict + return tvmc_model def test_load_model__onnx(onnx_resnet50): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") - mod, params = verify_load_model__onnx(onnx_resnet50) + tvmc_model = verify_load_model__onnx(onnx_resnet50) # check whether one known value is part of the params dict - assert "resnetv24_batchnorm0_gamma" in params.keys() - mod, params = verify_load_model__onnx(onnx_resnet50, freeze_params=True) + assert "resnetv24_batchnorm0_gamma" in tvmc_model.params.keys() + tvmc_model = verify_load_model__onnx(onnx_resnet50, freeze_params=True) # 
check that the parameter dict is empty, implying that they have been folded into constants - assert params == {} + assert tvmc_model.params == {} def test_load_model__pb(pb_mobilenet_v1_1_quant): # some CI environments wont offer TensorFlow, so skip in case it is not present pytest.importorskip("tensorflow") - mod, params = tvmc.load(pb_mobilenet_v1_1_quant) - assert type(mod) is IRModule - assert type(params) is dict + tvmc_model = tvmc.load(pb_mobilenet_v1_1_quant) + assert type(tvmc_model) is TVMCModel + assert type(tvmc_model.mod) is IRModule + assert type(tvmc_model.params) is dict # check whether one known value is part of the params dict - assert "MobilenetV1/Conv2d_0/weights" in params.keys() + assert "MobilenetV1/Conv2d_0/weights" in tvmc_model.params.keys() def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant): @@ -188,11 +193,12 @@ def test_load_model__pth(pytorch_resnet18): pytest.importorskip("torch") pytest.importorskip("torchvision") - mod, params = tvmc.load(pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]}) - assert type(mod) is IRModule - assert type(params) is dict + tvmc_model = tvmc.load(pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]}) + assert type(tvmc_model) is TVMCModel + assert type(tvmc_model.mod) is IRModule + assert type(tvmc_model.params) is dict # check whether one known value is part of the params dict - assert "layer1.0.conv1.weight" in params.keys() + assert "layer1.0.conv1.weight" in tvmc_model.params.keys() def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py new file mode 100644 index 000000000000..f5a28d419cbb --- /dev/null +++ b/tests/python/driver/tvmc/test_model.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
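Note: the frontend changes above make every loader return a `TVMCModel` wrapper rather than a bare `(mod, params)` tuple. A short sketch of the accessors the tests rely on; the model path is illustrative, the `shape_dict` form mirrors the PyTorch test above:

```python
from tvm.driver import tvmc
from tvm.driver.tvmc.model import TVMCModel

tvmc_model = tvmc.load("resnet18.pth", shape_dict={"input": [1, 3, 224, 224]})  # illustrative
assert isinstance(tvmc_model, TVMCModel)
# the wrapped Relay module and its parameter dict remain directly accessible
mod, params = tvmc_model.mod, tvmc_model.params
```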
+import pytest
+import os
+
+from os import path
+
+from tvm.driver import tvmc
+from tvm.driver.tvmc.model import TVMCModel, TVMCPackage, TVMCResult
+
+
+def test_tvmc_workflow(keras_simple):
+ pytest.importorskip("tensorflow")
+
+ tvmc_model = tvmc.load(keras_simple)
+ tuning_records = tvmc.tune(tvmc_model, target="llvm", enable_autoscheduler=True, trials=2)
+ tvmc_package = tvmc.compile(tvmc_model, tuning_records=tuning_records, target="llvm")
+ result = tvmc.run(tvmc_package, device="cpu")
+ assert type(tvmc_model) is TVMCModel
+ assert type(tvmc_package) is TVMCPackage
+ assert type(result) is TVMCResult
+ assert path.exists(tuning_records)
+ assert type(result.outputs) is dict
+ assert type(result.times) is tuple
+ assert "output_0" in result.outputs.keys()
+
+
+def test_save_load_model(keras_simple, tmpdir_factory):
+ pytest.importorskip("tensorflow")
+
+ tmpdir = tmpdir_factory.mktemp("data")
+ tvmc_model = tvmc.load(keras_simple)
+
+ # Create tuning artifacts
+ tvmc.tune(tvmc_model, target="llvm", trials=2)
+
+ # Create package artifacts
+ tvmc.compile(tvmc_model, target="llvm")
+
+ # Save the model to disk
+ model_path = os.path.join(tmpdir, "saved_model.tar")
+ tvmc_model.save(model_path)
+
+ # Load the model into a new TVMCModel
+ new_tvmc_model = TVMCModel(model_path=model_path)
+
+ # Check that the two models match.
+ assert str(new_tvmc_model.mod) == str(tvmc_model.mod)
+ # Check that tuning records and the compiled package are recoverable.
+ assert path.exists(new_tvmc_model.default_package_path())
+ assert path.exists(new_tvmc_model.default_tuning_records_path())
diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py
index cbea7e3d9d2b..5277a790b43b 100644
--- a/tests/python/driver/tvmc/test_runner.py
+++ b/tests/python/driver/tvmc/test_runner.py
@@ -18,6 +18,8 @@
 import numpy as np

 from tvm.driver import tvmc
+from tvm.driver.tvmc.model import TVMCResult
+from tvm.driver.tvmc.result_utils import get_top_results


 def test_generate_tensor_data_zeros():
@@ -50,14 +52,16 @@ def test_generate_tensor_data__type_unknown():


 def test_format_times__contains_header():
- sut = tvmc.runner.format_times([0.6, 1.2, 0.12, 0.42])
+ fake_result = TVMCResult(outputs=None, times=[0.6, 1.2, 0.12, 0.42])
+ sut = fake_result.format_times()
 assert "std (ms)" in sut


 def test_get_top_results_keep_results():
 fake_outputs = {"output_0": np.array([[1, 2, 3, 4], [5, 6, 7, 8]])}
+ fake_result = TVMCResult(outputs=fake_outputs, times=None)
 number_of_results_wanted = 3
- sut = tvmc.runner.get_top_results(fake_outputs, number_of_results_wanted)
+ sut = get_top_results(fake_result, number_of_results_wanted)

 expected_number_of_lines = 2
 assert len(sut) == expected_number_of_lines
@@ -67,16 +71,14 @@ def test_get_top_results_keep_results():
 assert len(sut[1]) == expected_number_of_results_per_line


-def test_run_tflite_module__with_profile__valid_input(
- tflite_compiled_module_as_tarfile, imagenet_cat
-):
+def test_run_tflite_module__with_profile__valid_input(tflite_compiled_model, imagenet_cat):
 # some CI environments wont offer TFLite, so skip in case it is not present
 pytest.importorskip("tflite")

 inputs = np.load(imagenet_cat)

- outputs, times = tvmc.run(
- tflite_compiled_module_as_tarfile,
+ result = tvmc.run(
+ tflite_compiled_model,
 inputs=inputs,
 hostname=None,
 device="cpu",
@@ -84,7 +86,7 @@ def test_run_tflite_module__with_profile__valid_input(
 )

 # collect the top 5 results
- top_5_results = tvmc.runner.get_top_results(outputs, 5)
+ top_5_results = 
get_top_results(result, 5) top_5_ids = top_5_results[0] # IDs were collected from this reference: @@ -95,6 +97,6 @@ def test_run_tflite_module__with_profile__valid_input( assert ( tiger_cat_mobilenet_id in top_5_ids ), "tiger cat is expected in the top-5 for mobilenet v1" - assert type(outputs) is dict - assert type(times) is tuple - assert "output_0" in outputs.keys() + assert type(result.outputs) is dict + assert type(result.times) is tuple + assert "output_0" in result.outputs.keys() diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index 474649d8b1b3..078076b479ea 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -15,13 +15,10 @@ # specific language governing permissions and limitations # under the License. import argparse -import os -from os import path import pytest import tvm -from tvm import relay from tvm.driver import tvmc from tvm.driver.tvmc.common import TVMCException @@ -31,7 +28,8 @@ def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): # some CI environments wont offer TFLite, so skip in case it is not present pytest.importorskip("tflite") - before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + before = tvmc_model.mod expected_layout = "NCHW" after = tvmc.common.convert_graph_layout(before, expected_layout) @@ -55,7 +53,8 @@ def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): # some CI environments wont offer ONNX, so skip in case it is not present pytest.importorskip("onnx") - before, _ = tvmc.frontends.load_model(onnx_resnet50) + tvmc_model = tvmc.frontends.load_model(onnx_resnet50) + before = tvmc_model.mod expected_layout = "NHWC" after = tvmc.common.convert_graph_layout(before, expected_layout) @@ -79,7 +78,8 @@ def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_ # some CI environments wont offer TFLite, so skip in case it is not present pytest.importorskip("tflite") - before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + before = tvmc_model.mod expected_layout = "NHWC" after = tvmc.common.convert_graph_layout(before, expected_layout) @@ -103,7 +103,8 @@ def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): # some CI environments wont offer ONNX, so skip in case it is not present pytest.importorskip("onnx") - before, _ = tvmc.frontends.load_model(onnx_resnet50) + tvmc_model = tvmc.frontends.load_model(onnx_resnet50) + before = tvmc_model.mod expected_layout = "NCHW" after = tvmc.common.convert_graph_layout(before, expected_layout)
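Note: the layout-conversion tests above all follow the same pattern; a compact sketch of the helper they exercise, with an illustrative model path:

```python
from tvm.driver import tvmc

tvmc_model = tvmc.frontends.load_model("mobilenet_v1_1.0_224_quant.tflite")  # illustrative
before = tvmc_model.mod
# convert_graph_layout rewrites the module to the desired layout ("NCHW" or "NHWC");
# converting to the layout the model already uses is effectively a no-op.
after = tvmc.common.convert_graph_layout(before, "NCHW")
```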