From c17c625ba77f646debe9440696615b3e765061d4 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Fri, 2 Oct 2020 03:25:56 +0100 Subject: [PATCH] [tvmc] Introduce 'run' subcommand (part 4/4) (#6578) * [tvmc] Introduce 'run' subcommand (part 4/4) * Add 'tvmc run' subcommand to execute compiled modules * Include options to locally or remotelly using RPC * Include support to cpu and gpu devices Co-authored-by: Marcus Shawcroft Co-authored-by: Matthew Barrett * adjust based on code review comments * make test fixture to safely skip environments without tflite * make --help option more clear * improve error message to show expected inputs * code-review adjusts * update doc-string to default zeros->random Co-authored-by: Marcus Shawcroft Co-authored-by: Matthew Barrett --- python/tvm/driver/tvmc/__init__.py | 1 + python/tvm/driver/tvmc/common.py | 35 ++ python/tvm/driver/tvmc/compiler.py | 4 +- python/tvm/driver/tvmc/runner.py | 464 ++++++++++++++++++++++++ tests/python/driver/tvmc/conftest.py | 41 ++- tests/python/driver/tvmc/test_common.py | 31 ++ tests/python/driver/tvmc/test_runner.py | 98 +++++ 7 files changed, 667 insertions(+), 7 deletions(-) create mode 100644 python/tvm/driver/tvmc/runner.py create mode 100644 tests/python/driver/tvmc/test_runner.py diff --git a/python/tvm/driver/tvmc/__init__.py b/python/tvm/driver/tvmc/__init__.py index 5926ca42e3639..d96a725877ebd 100644 --- a/python/tvm/driver/tvmc/__init__.py +++ b/python/tvm/driver/tvmc/__init__.py @@ -20,3 +20,4 @@ from . import autotuner from . import compiler +from . import runner diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 63e0d708f5277..a625a99f0e7e5 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -20,6 +20,8 @@ import logging import os.path +from urllib.parse import urlparse + import tvm from tvm import relay @@ -102,3 +104,36 @@ def target_from_cli(target): logger.debug("creating target from input: %s", target) return tvm.target.Target(target) + + +def tracker_host_port_from_cli(rpc_tracker_str): + """Extract hostname and (optional) port from strings + like "1.2.3.4:9090" or "4.3.2.1". + + Used as a helper function to cover --rpc-tracker + command line argument, in different subcommands. + + Parameters + ---------- + rpc_tracker_str : str + hostname (or IP address) and port of the RPC tracker, + in the format 'hostname[:port]'. + + Returns + ------- + rpc_hostname : str or None + hostname or IP address, extracted from input. + rpc_port : int or None + port number extracted from input (9090 default). + """ + + rpc_hostname = rpc_port = None + + if rpc_tracker_str: + parsed_url = urlparse("//%s" % rpc_tracker_str) + rpc_hostname = parsed_url.hostname + rpc_port = parsed_url.port or 9090 + logger.info("RPC tracker hostname: %s", rpc_hostname) + logger.info("RPC tracker port: %s", rpc_port) + + return rpc_hostname, rpc_port diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 831cec7446895..8001ee29f757d 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -190,8 +190,6 @@ def compile_model( target_host = target_host or "" if tuning_records and os.path.exists(tuning_records): - # TODO (@leandron) a new PR will introduce the 'tune' subcommand - # the is used to generate the tuning records file logger.debug("tuning records file provided: %s", tuning_records) with autotvm.apply_history_best(tuning_records): with tvm.transform.PassContext(opt_level=3): @@ -212,6 +210,8 @@ def compile_model( source = str(mod) if source_type == "relay" else lib.get_source(source_type) dumps[source_type] = source + # TODO we need to update this return to use the updated graph module APIs + # as these getter functions will be deprecated in the next release (@leandron) return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py new file mode 100644 index 0000000000000..d86d4db795dcd --- /dev/null +++ b/python/tvm/driver/tvmc/runner.py @@ -0,0 +1,464 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Provides support to run compiled networks both locally and remotely. +""" +import json +import logging +import os +import tarfile +import tempfile + +import numpy as np +import tvm +from tvm import rpc +from tvm.autotvm.measure import request_remote +from tvm.contrib import graph_runtime as runtime +from tvm.contrib.debugger import debug_runtime + +from . import common +from .common import TVMCException +from .main import register_parser + + +# pylint: disable=invalid-name +logger = logging.getLogger("TVMC") + + +@register_parser +def add_run_parser(subparsers): + """ Include parser for 'run' subcommand """ + + parser = subparsers.add_parser("run", help="run a compiled module") + parser.set_defaults(func=drive_run) + + # TODO --device needs to be extended and tested to support other targets, + # like 'cl', 'webgpu', etc (@leandron) + parser.add_argument( + "--device", + choices=["cpu", "gpu"], + default="cpu", + help="target device to run the compiled module. Defaults to 'cpu'", + ) + parser.add_argument( + "--fill-mode", + choices=["zeros", "ones", "random"], + default="random", + help="fill all input tensors with values. In case --inputs/-i is provided, " + "they will take precedence over --fill-mode. Any remaining inputs will be " + "filled using the chosen fill mode. Defaults to 'random'", + ) + parser.add_argument("-i", "--inputs", help="path to the .npz input file") + parser.add_argument("-o", "--outputs", help="path to the .npz output file") + parser.add_argument( + "--print-time", action="store_true", help="record and print the execution time(s)" + ) + parser.add_argument( + "--print-top", + metavar="N", + type=int, + help="print the top n values and indices of the output tensor", + ) + parser.add_argument( + "--profile", + action="store_true", + help="generate profiling data from the runtime execution. " + "Using --profile requires the Graph Runtime Debug enabled on TVM. " + "Profiling may also have an impact on inference time, " + "making it take longer to be generated.", + ) + parser.add_argument( + "--repeat", metavar="N", type=int, default=1, help="repeat the run n times. Defaults to '1'" + ) + parser.add_argument( + "--rpc-key", + nargs=1, + help="the RPC tracker key of the target device", + ) + parser.add_argument( + "--rpc-tracker", + nargs=1, + help="hostname (required) and port (optional, defaults to 9090) of the RPC tracker, " + "e.g. '192.168.0.100:9999'", + ) + parser.add_argument("FILE", help="path to the compiled module file") + + +def drive_run(args): + """Invoke runner module with command line arguments + + Parameters + ---------- + args: argparse.Namespace + Arguments from command line parser. + """ + + rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker) + + outputs, times = run_module( + args.FILE, + rpc_hostname, + rpc_port, + args.rpc_key, + inputs_file=args.inputs, + device=args.device, + fill_mode=args.fill_mode, + repeat=args.repeat, + profile=args.profile, + ) + + if args.print_time: + stat_table = format_times(times) + # print here is intentional + print(stat_table) + + if args.print_top: + top_results = get_top_results(outputs, args.print_top) + # print here is intentional + print(top_results) + + if args.outputs: + # Save the outputs + np.savez(args.outputs, **outputs) + + +def get_input_info(graph_str, params): + """Return the 'shape' and 'dtype' dictionaries for the input + tensors of a compiled module. + + .. note:: + We can't simply get the input tensors from a TVM graph + because weight tensors are treated equivalently. Therefore, to + find the input tensors we look at the 'arg_nodes' in the graph + (which are either weights or inputs) and check which ones don't + appear in the params (where the weights are stored). These nodes + are therefore inferred to be input tensors. + + Parameters + ---------- + graph_str : str + JSON graph of the module serialized as a string. + params : bytearray + Params serialized as a bytearray. + + Returns + ------- + shape_dict : dict + Shape dictionary - {input_name: tuple}. + dtype_dict : dict + dtype dictionary - {input_name: dtype}. + """ + + shape_dict = {} + dtype_dict = {} + # Use a special function to load the binary params back into a dict + load_arr = tvm.get_global_func("tvm.relay._load_param_dict")(params) + param_names = [v.name for v in load_arr] + graph = json.loads(graph_str) + for node_id in graph["arg_nodes"]: + node = graph["nodes"][node_id] + # If a node is not in the params, infer it to be an input node + name = node["name"] + if name not in param_names: + shape_dict[name] = graph["attrs"]["shape"][1][node_id] + dtype_dict[name] = graph["attrs"]["dltype"][1][node_id] + + logger.debug("collecting graph input shape and type:") + logger.debug("graph input shape: %s", shape_dict) + logger.debug("graph input type: %s", dtype_dict) + + return shape_dict, dtype_dict + + +def generate_tensor_data(shape, dtype, fill_mode): + """Generate data to produce a tensor of given shape and dtype. + + Random data generation depends on the dtype. For int8 types, + random integers in the range 0->255 are generated. For all other + types, random floats are generated in the range -1->1 and then + cast to the appropriate dtype. + + This is used to quickly generate some data to input the models, as + a way to check that compiled module is sane for running. + + Parameters + ---------- + shape : tuple + The shape of the tensor. + dtype : str + The dtype of the tensor. + fill_mode : str + The fill-mode to use, either "zeros", "ones" or "random". + + Returns + ------- + tensor : np.array + The generated tensor as a np.array. + """ + if fill_mode == "zeros": + tensor = np.zeros(shape=shape, dtype=dtype) + elif fill_mode == "ones": + tensor = np.ones(shape=shape, dtype=dtype) + elif fill_mode == "random": + if "int8" in dtype: + tensor = np.random.randint(128, size=shape, dtype=dtype) + else: + tensor = np.random.uniform(-1, 1, size=shape).astype(dtype) + else: + raise TVMCException("unknown fill-mode: {}".format(fill_mode)) + + return tensor + + +def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode): + """Make the inputs dictionary for a graph. + + Use data from 'inputs' where specified. For input tensors + where no data has been given, generate data according to the + chosen fill-mode. + + Parameters + ---------- + inputs_file : str + Path to a .npz file containing the inputs. + shape_dict : dict + Shape dictionary - {input_name: tuple}. + dtype_dict : dict + dtype dictionary - {input_name: dtype}. + fill_mode : str + The fill-mode to use when generating tensor data. + Can be either "zeros", "ones" or "random". + + Returns + ------- + inputs_dict : dict + Complete inputs dictionary - {input_name: np.array}. + """ + logger.debug("creating inputs dict") + + try: + inputs = np.load(inputs_file) if inputs_file else {} + except IOError as ex: + raise TVMCException("Error loading inputs file: %s" % ex) + + # First check all the keys in inputs exist in the graph + for input_name in inputs: + if input_name not in shape_dict.keys(): + raise TVMCException( + "the input tensor '{}' is not in the graph. Expected inputs: '{}'".format( + input_name, shape_dict.keys() + ) + ) + + # Now construct the input dict, generating tensors where no + # data already exists in 'inputs' + inputs_dict = {} + for input_name in shape_dict: + if input_name in inputs.keys(): + logger.debug("setting input '%s' with user input data", input_name) + inputs_dict[input_name] = inputs[input_name] + else: + shape = shape_dict[input_name] + dtype = dtype_dict[input_name] + + logger.debug( + "generating data for input '%s' (shape: %s, dtype: %s), using fill-mode '%s'", + input_name, + shape, + dtype, + fill_mode, + ) + data = generate_tensor_data(shape, dtype, fill_mode) + inputs_dict[input_name] = data + + return inputs_dict + + +def run_module( + module_file, + hostname, + port=9090, + rpc_key=None, + device=None, + inputs_file=None, + fill_mode="random", + repeat=1, + profile=False, +): + """Run a compiled graph runtime module locally or remotely with + optional input values. + + If input tensors are not specified explicitly, they can be filled + with zeroes, ones or random data. + + Parameters + ---------- + module_file : str + The path to the module file (a .tar file). + hostname : str + The hostname of the target device on which to run. + port : int, optional + The port of the target device on which to run. + rpc_key : str, optional + The tracker key of the target device. If this is set, it + will be assumed that remote points to a tracker. + device: str, optional + the device (e.g. "cpu" or "gpu") to be targeted by the RPC + session, local or remote). + inputs_file : str, optional + Path to an .npz file containing the inputs. + fill_mode : str, optional + The fill-mode to use when generating data for input tensors. + Valid options are "zeros", "ones" and "random". + Defaults to "random". + repeat : int, optional + How many times to repeat the run. + profile : bool + Whether to profile the run with the debug runtime. + + Returns + ------- + outputs : dict + a dictionary with output tensors, generated by the module + times : list of str + execution times generated by the time evaluator + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + logger.debug("extracting module file %s", module_file) + t = tarfile.open(module_file) + t.extractall(tmp_dir) + graph = open(os.path.join(tmp_dir, "mod.json")).read() + params = bytearray(open(os.path.join(tmp_dir, "mod.params"), "rb").read()) + + if hostname: + # Remote RPC + if rpc_key: + logger.debug("running on remote RPC tracker with key %s", rpc_key) + session = request_remote(rpc_key, hostname, port, timeout=1000) + else: + logger.debug("running on remote RPC with no key") + session = rpc.connect(hostname, port) + else: + # Local + logger.debug("running a local session") + session = rpc.LocalSession() + + session.upload(os.path.join(tmp_dir, "mod.so")) + lib = session.load_module("mod.so") + + # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) + logger.debug("device is %s", device) + ctx = session.cpu() if device == "cpu" else session.gpu() + + if profile: + logger.debug("creating runtime with profiling enabled") + module = debug_runtime.create(graph, lib, ctx, dump_root="./prof") + else: + logger.debug("creating runtime with profiling disabled") + module = runtime.create(graph, lib, ctx) + + logger.debug("load params into the runtime module") + module.load_params(params) + + shape_dict, dtype_dict = get_input_info(graph, params) + inputs_dict = make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode) + + logger.debug("setting inputs to the module") + module.set_input(**inputs_dict) + + # Run must be called explicitly if profiling + if profile: + logger.debug("running the module with profiling enabled") + module.run() + + # create the module time evaluator (returns a function) + timer = module.module.time_evaluator("run", ctx, 1, repeat=repeat) + # call the evaluator function to invoke the module and save execution times + prof_result = timer() + # collect a list of execution times from the profiling results + times = prof_result.results + + logger.debug("collecting the output tensors") + num_outputs = module.get_num_outputs() + outputs = {} + for i in range(num_outputs): + output_name = "output_{}".format(i) + outputs[output_name] = module.get_output(i).asnumpy() + + return outputs, times + + +def get_top_results(outputs, max_results): + """Return the top n results from the output tensor. + + This function is primarily for image classification and will + not necessarily generalise. + + Parameters + ---------- + outputs : dict + Outputs dictionary - {output_name: np.array}. + max_results : int + Number of results to return + + Returns + ------- + top_results : np.array + Results array of shape (2, n). + The first row is the indices and the second is the values. + + """ + output = outputs["output_0"] + sorted_labels = output.argsort()[0][-max_results:][::-1] + output.sort() + sorted_values = output[0][-max_results:][::-1] + top_results = np.array([sorted_labels, sorted_values]) + return top_results + + +def format_times(times): + """Format the mean, max, min and std of the execution times. + + This has the effect of producing a small table that looks like: + + Execution time summary: + mean (s) max (s) min (s) std (s) + 0.14310 0.16161 0.12933 0.01004 + + Parameters + ---------- + times : list + A list of execution times (in seconds). + + Returns + ------- + str + A formatted string containing the statistics. + """ + + # timestamps + mean_ts = np.mean(times) + std_ts = np.std(times) + max_ts = np.max(times) + min_ts = np.min(times) + + header = "Execution time summary:\n{0:^10} {1:^10} {2:^10} {3:^10}".format( + "mean (s)", "max (s)", "min (s)", "std (s)" + ) + stats = "{0:^10.5f} {1:^10.5f} {2:^10.5f} {3:^10.5f}".format(mean_ts, max_ts, min_ts, std_ts) + return "%s\n%s\n" % (header, stats) diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index ee67cc904aace..21ebb0f96bbc6 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -18,11 +18,13 @@ import pytest import tarfile -import tvm.driver.tvmc.compiler +import numpy as np -from tvm.contrib.download import download_testdata +from PIL import Image + +from tvm.driver import tvmc -from tvm.driver.tvmc.common import convert_graph_layout +from tvm.contrib.download import download_testdata # Support functions @@ -40,7 +42,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir): def get_sample_compiled_module(target_dir): - """Support function that retuns a TFLite compiled module""" + """Support function that returns a TFLite compiled module""" base_url = "https://storage.googleapis.com/download.tensorflow.org/models" model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" model_file = download_and_untar( @@ -49,7 +51,7 @@ def get_sample_compiled_module(target_dir): temp_dir=target_dir, ) - return tvmc.compiler.compile_model(model_file, targets=["llvm"]) + return tvmc.compiler.compile_model(model_file, target="llvm") # PyTest fixtures @@ -110,6 +112,18 @@ def onnx_resnet50(): @pytest.fixture(scope="session") def tflite_compiled_module_as_tarfile(tmpdir_factory): + + # Not all CI environments will have TFLite installed + # so we need to safely skip this fixture that will + # crash the tests that rely on it. + # As this is a pytest.fixture, we cannot take advantage + # of pytest.importorskip. Using the block below instead. + try: + import tflite + except ImportError: + print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.") + return "" + target_dir = tmpdir_factory.mktemp("data") graph, lib, params, _ = get_sample_compiled_module(target_dir) @@ -117,3 +131,20 @@ def tflite_compiled_module_as_tarfile(tmpdir_factory): tvmc.compiler.save_module(module_file, graph, lib, params) return module_file + + +@pytest.fixture(scope="session") +def imagenet_cat(tmpdir_factory): + tmpdir_name = tmpdir_factory.mktemp("data") + cat_file_name = "imagenet_cat.npz" + + cat_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true" + image_path = download_testdata(cat_url, "inputs", module=["tvmc"]) + resized_image = Image.open(image_path).resize((224, 224)) + image_data = np.asarray(resized_image).astype("float32") + image_data = np.expand_dims(image_data, axis=0) + + cat_file_full_path = os.path.join(tmpdir_name, cat_file_name) + np.savez(cat_file_full_path, input=image_data) + + return cat_file_full_path diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_common.py index a9a62c5ef874f..5ffbc6fe37dd7 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_common.py @@ -118,3 +118,34 @@ def _is_layout_transform(node): tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" + + +def test_tracker_host_port_from_cli__hostname_port(): + input_str = "1.2.3.4:9090" + expected_host = "1.2.3.4" + expected_port = 9090 + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert expected_host == actual_host + assert expected_port == actual_port + + +def test_tracker_host_port_from_cli__hostname_port__empty(): + input_str = "" + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert actual_host is None + assert actual_port is None + + +def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): + input_str = "1.2.3.4" + expected_host = "1.2.3.4" + expected_port = 9090 + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert expected_host == actual_host + assert expected_port == actual_port diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py new file mode 100644 index 0000000000000..544ed9f7e9df4 --- /dev/null +++ b/tests/python/driver/tvmc/test_runner.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import numpy as np + +from tvm.driver import tvmc + + +def test_generate_tensor_data_zeros(): + expected_shape = (2, 3) + expected_dtype = "uint8" + sut = tvmc.runner.generate_tensor_data(expected_shape, expected_dtype, "zeros") + + assert sut.shape == (2, 3) + + +def test_generate_tensor_data_ones(): + expected_shape = (224, 224) + expected_dtype = "uint8" + sut = tvmc.runner.generate_tensor_data(expected_shape, expected_dtype, "ones") + + assert sut.shape == (224, 224) + + +def test_generate_tensor_data_random(): + expected_shape = (2, 3) + expected_dtype = "uint8" + sut = tvmc.runner.generate_tensor_data(expected_shape, expected_dtype, "random") + + assert sut.shape == (2, 3) + + +def test_generate_tensor_data__type_unknown(): + with pytest.raises(tvmc.common.TVMCException) as e: + tvmc.runner.generate_tensor_data((2, 3), "float32", "whatever") + + +def test_format_times__contains_header(): + sut = tvmc.runner.format_times([60, 120, 12, 42]) + assert "std (s)" in sut + + +def test_get_top_results_keep_results(): + fake_outputs = {"output_0": np.array([[1, 2, 3, 4], [5, 6, 7, 8]])} + number_of_results_wanted = 3 + sut = tvmc.runner.get_top_results(fake_outputs, number_of_results_wanted) + + expected_number_of_lines = 2 + assert len(sut) == expected_number_of_lines + + expected_number_of_results_per_line = 3 + assert len(sut[0]) == expected_number_of_results_per_line + assert len(sut[1]) == expected_number_of_results_per_line + + +def test_run_tflite_module__with_profile__valid_input( + tflite_compiled_module_as_tarfile, imagenet_cat +): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + outputs, times = tvmc.runner.run_module( + tflite_compiled_module_as_tarfile, + inputs_file=imagenet_cat, + hostname=None, + device="cpu", + profile=True, + ) + + # collect the top 5 results + top_5_results = tvmc.runner.get_top_results(outputs, 5) + top_5_ids = top_5_results[0] + + # IDs were collected from this reference: + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/ + # java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt + tiger_cat_mobilenet_id = 283 + + assert ( + tiger_cat_mobilenet_id in top_5_ids + ), "tiger cat is expected in the top-5 for mobilenet v1" + assert type(outputs) is dict + assert type(times) is tuple + assert "output_0" in outputs.keys()