Allow tvmc to compile models with AOT executor in MLF #8331

Merged · 3 commits · Jul 1, 2021
6 changes: 3 additions & 3 deletions python/tvm/driver/tvmc/compiler.py
@@ -171,7 +171,7 @@ def compile_model(
target_host: Optional[str] = None,
desired_layout: Optional[str] = None,
disabled_pass: Optional[str] = None,
pass_context_configs: Optional[str] = None,
pass_context_configs: Optional[List[str]] = None,
):
"""Compile a model from a supported framework into a TVM module.

@@ -212,8 +212,8 @@ def compile_model(
disabled_pass: str, optional
Comma-separated list of passes which needs to be disabled
during compilation
pass_context_configs: str, optional
String containing a set of configurations to be passed to the
pass_context_configs: list[str], optional
List of strings containing a set of configurations to be passed to the
PassContext.


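With this change, `pass_context_configs` is a list of strings rather than a single comma-separated string. A minimal sketch of how a caller might drive the updated Python API — the model file and output path are placeholders, and the AOT target mirrors the one used in the tests below:

```python
from tvm.driver import tvmc

# Load a model through a tvmc frontend (a quantized TFLite MobileNet is
# what this PR's tests use; any supported format works).
tvmc_model = tvmc.frontends.load_model("mobilenet_v1_1.0_224_quant.tflite")

# Compile with the AOT executor into a Model Library Format archive,
# passing PassContext options as a list of "key=value" strings.
tvmc.compiler.compile_model(
    tvmc_model,
    target="c --executor=aot",
    package_path="module.tar",
    output_format="mlf",
    pass_context_configs=["tir.disable_vectorize=1"],
)
```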
13 changes: 10 additions & 3 deletions python/tvm/driver/tvmc/model.py
@@ -45,6 +45,7 @@
"""
import os
import tarfile
import json
from typing import Optional, Union, List, Dict, Callable, TextIO
import numpy as np

@@ -332,8 +333,11 @@ def import_package(self, package_path: str):
# Model Library Format (MLF)
self.lib_name = None
self.lib_path = None
with open(temp.relpath("metadata.json")) as metadata_json:
metadata = json.load(metadata_json)

graph = temp.relpath("runtime-config/graph/graph.json")
is_graph_runtime = "graph" in metadata["runtimes"]
graph = temp.relpath("runtime-config/graph/graph.json") if is_graph_runtime else None
params = temp.relpath("parameters/default.params")

self.type = "mlf"
@@ -357,8 +361,11 @@ def import_package(self, package_path: str):
with open(params, "rb") as param_file:
self.params = bytearray(param_file.read())

with open(graph) as graph_file:
self.graph = graph_file.read()
if graph is not None:
with open(graph) as graph_file:
self.graph = graph_file.read()
else:
self.graph = None


class TVMCResult(object):
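The key behavioural change in `import_package` is that the graph JSON becomes optional: the importer now reads `metadata.json` and only expects `runtime-config/graph/graph.json` when the archive lists the graph runtime. A standalone sketch of that check, assuming an illustrative metadata payload (the `"aot"` runtime name is an assumption, not taken from this diff):

```python
import json

# Illustrative metadata for an AOT-only MLF archive; the real metadata.json
# contains more fields than shown here.
example_metadata = '{"runtimes": ["aot"]}'
metadata = json.loads(example_metadata)

# Same check as import_package: only look for graph.json when the graph
# runtime is present in the archive.
is_graph_runtime = "graph" in metadata["runtimes"]
graph_path = "runtime-config/graph/graph.json" if is_graph_runtime else None
print(graph_path)  # None for an AOT-only archive
```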
60 changes: 9 additions & 51 deletions tests/python/driver/tvmc/conftest.py
@@ -41,25 +41,6 @@ def download_and_untar(model_url, model_sub_path, temp_dir):
return os.path.join(temp_dir, model_sub_path)


def get_sample_compiled_module(target_dir, package_filename, output_format="so"):
"""Support function that returns a TFLite compiled module"""
base_url = "https://storage.googleapis.com/download.tensorflow.org/models"
model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
model_file = download_and_untar(
"{}/{}".format(base_url, model_url),
"mobilenet_v1_1.0_224_quant.tflite",
temp_dir=target_dir,
)

tvmc_model = tvmc.frontends.load_model(model_file)
return tvmc.compiler.compile_model(
tvmc_model,
target="llvm",
package_path=os.path.join(target_dir, package_filename),
output_format=output_format,
)


# PyTest fixtures


@@ -167,40 +148,17 @@ def onnx_mnist():
return model_file


@pytest.fixture(scope="session")
def tflite_compiled_model(tmpdir_factory):

# Not all CI environments will have TFLite installed
# so we need to safely skip this fixture that will
# crash the tests that rely on it.
# As this is a pytest.fixture, we cannot take advantage
# of pytest.importorskip. Using the block below instead.
try:
import tflite
except ImportError:
print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.")
return ""

target_dir = tmpdir_factory.mktemp("data")
return get_sample_compiled_module(target_dir, "mock.tar")


@pytest.fixture(scope="session")
def tflite_compiled_model_mlf(tmpdir_factory):
@pytest.fixture
def tflite_compile_model(tmpdir_factory):
"""Support function that returns a TFLite compiled module"""

# Not all CI environments will have TFLite installed
# so we need to safely skip this fixture that will
# crash the tests that rely on it.
# As this is a pytest.fixture, we cannot take advantage
# of pytest.importorskip. Using the block below instead.
try:
import tflite
except ImportError:
print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.")
return ""
def model_compiler(model_file, **overrides):
package_path = tmpdir_factory.mktemp("data").join("mock.tar")
tvmc_model = tvmc.frontends.load_model(model_file)
args = {"target": "llvm", **overrides}
return tvmc.compiler.compile_model(tvmc_model, package_path=package_path, **args)

target_dir = tmpdir_factory.mktemp("data")
return get_sample_compiled_module(target_dir, "mock.tar", "mlf")
return model_compiler


@pytest.fixture(scope="session")
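The session-scoped `tflite_compiled_model*` fixtures are replaced by a `tflite_compile_model` factory, so each test can compile the model with its own target, output format, and pass configs. An illustrative test using the factory (the test name and assertion are examples, not part of this change):

```python
import os

import pytest


def test_compile_tflite_to_mlf(tflite_mobilenet_v1_1_quant, tflite_compile_model):
    # Skip when TFLite is unavailable, as the other tests in this PR do.
    pytest.importorskip("tflite")

    # Compile with a per-test override on top of the fixture's "llvm" default target.
    package = tflite_compile_model(tflite_mobilenet_v1_1_quant, output_format="mlf")
    assert os.path.exists(package.package_path)
```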
47 changes: 40 additions & 7 deletions tests/python/driver/tvmc/test_mlf.py
@@ -17,25 +17,28 @@

import pytest
import os
import shlex

import tvm
from tvm.driver import tvmc
from tvm.driver.tvmc.main import _main
from tvm.driver.tvmc.model import TVMCPackage, TVMCException


def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
@pytest.mark.parametrize(
"target,pass_configs", [["llvm", []], ["c --executor=aot", ["tir.disable_vectorize=1"]]]
)
def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory, target, pass_configs):
pytest.importorskip("tflite")

output_dir = tmpdir_factory.mktemp("mlf")
input_model = tflite_mobilenet_v1_1_quant
output_file = os.path.join(output_dir, "mock.tar")

# Compile the input model and generate a Model Library Format (MLF) archive.
tvmc_cmd = (
f"tvmc compile {input_model} --target='llvm' --output {output_file} --output-format mlf"
)
tvmc_args = tvmc_cmd.split(" ")[1:]
pass_config_args = " ".join([f"--pass-config {pass_config}" for pass_config in pass_configs])
tvmc_cmd = f"tvmc compile {input_model} --target='{target}' {pass_config_args} --output {output_file} --output-format mlf"
tvmc_args = shlex.split(tvmc_cmd)[1:]
_main(tvmc_args)
assert os.path.exists(output_file), "Could not find the exported MLF archive."

@@ -82,9 +85,39 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
assert str(exp.value) == expected_reason, on_error


def test_tvmc_import_package_mlf(tflite_compiled_model_mlf):
def test_tvmc_import_package_mlf_graph(tflite_mobilenet_v1_1_quant, tflite_compile_model):
pytest.importorskip("tflite")

tflite_compiled_model_mlf = tflite_compile_model(
tflite_mobilenet_v1_1_quant, output_format="mlf"
)

# Compile and export a model to a MLF archive so it can be imported.
exported_tvmc_package = tflite_compiled_model_mlf
archive_path = exported_tvmc_package.package_path

# Import the MLF archive. TVMCPackage constructor will call import_package method.
tvmc_package = TVMCPackage(archive_path)

assert tvmc_package.lib_name is None, ".lib_name must not be set in the MLF archive."
assert tvmc_package.lib_path is None, ".lib_path must not be set in the MLF archive."
assert (
tvmc_package.graph is not None
), ".graph must be set in the MLF archive for Graph executor."
assert tvmc_package.params is not None, ".params must be set in the MLF archive."
assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format."


def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_compile_model):
pytest.importorskip("tflite")

tflite_compiled_model_mlf = tflite_compile_model(
tflite_mobilenet_v1_1_quant,
target="c --executor=aot",
output_format="mlf",
pass_context_configs=["tir.disable_vectorize=1"],
)

# Compile and export a model to a MLF archive so it can be imported.
exported_tvmc_package = tflite_compiled_model_mlf
archive_path = exported_tvmc_package.package_path
@@ -94,6 +127,6 @@ def test_tvmc_import_package_mlf(tflite_compiled_model_mlf):

assert tvmc_package.lib_name is None, ".lib_name must not be set in the MLF archive."
assert tvmc_package.lib_path is None, ".lib_path must not be set in the MLF archive."
assert tvmc_package.graph is not None, ".graph must be set in the MLF archive."
assert tvmc_package.graph is None, ".graph must not be set in the MLF archive for AOT executor."
assert tvmc_package.params is not None, ".params must be set in the MLF archive."
assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format."
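For reference, the AOT case of the parametrized CLI test above boils down to an invocation like the following (the model and output paths are placeholders). `shlex.split` keeps the quoted target string as a single argument, which is why the test switches away from `str.split`:

```python
import shlex

from tvm.driver.tvmc.main import _main

cmd = (
    "tvmc compile model.tflite --target='c --executor=aot' "
    "--pass-config tir.disable_vectorize=1 "
    "--output mock.tar --output-format mlf"
)
# Drop the leading "tvmc" token, as the test does before calling _main.
_main(shlex.split(cmd)[1:])
```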
5 changes: 4 additions & 1 deletion tests/python/driver/tvmc/test_runner.py
@@ -71,12 +71,15 @@ def test_get_top_results_keep_results():
assert len(sut[1]) == expected_number_of_results_per_line


def test_run_tflite_module__with_profile__valid_input(tflite_compiled_model, imagenet_cat):
def test_run_tflite_module__with_profile__valid_input(
tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat
):
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

inputs = np.load(imagenet_cat)

tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant)
result = tvmc.run(
tflite_compiled_model,
inputs=inputs,