diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 3f1d04aee7fd..dcb770b9a563 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -47,6 +47,11 @@ def add_compile_parser(subparsers): default="", help="the cross compiler to generate target libraries, e.g. 'aarch64-linux-gnu-gcc'", ) + parser.add_argument( + "--cross-compiler-options", + default="", + help="the cross compiler options to generate target libraries, e.g. '-mfpu=neon-vfpv4'", + ) parser.add_argument( "--desired-layout", choices=["NCHW", "NHWC"], @@ -126,6 +131,7 @@ def drive_compile(args): tuning_records=args.tuning_records, package_path=args.output, cross=args.cross_compiler, + cross_options=args.cross_compiler_options, dump_code=dump_code, target_host=None, desired_layout=args.desired_layout, @@ -141,6 +147,7 @@ def compile_model( tuning_records: Optional[str] = None, package_path: Optional[str] = None, cross: Optional[Union[str, Callable]] = None, + cross_options: Optional[str] = None, export_format: str = "so", dump_code: Optional[List[str]] = None, target_host: Optional[str] = None, @@ -168,6 +175,8 @@ def compile_model( be saved in a temporary directory. cross : str or callable object, optional Function that performs the actual compilation + cross_options : str, optional + Command line options to be passed to the cross compiler. export_format : str What format to use when saving the function library. Must be one of "so" or "tar". When compiling for a remote device without a cross compiler, "tar" will likely work better. @@ -252,7 +261,9 @@ def compile_model( dumps[source_type] = source # Create a new tvmc model package object from the graph definition. - package_path = tvmc_model.export_package(graph_module, package_path, cross, export_format) + package_path = tvmc_model.export_package( + graph_module, package_path, cross, cross_options, export_format + ) # Write dumps to file. if dumps: diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index e48125f0f619..26a1e3600b96 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -180,6 +180,7 @@ def export_package( executor_factory: GraphExecutorFactoryModule, package_path: Optional[str] = None, cross: Optional[Union[str, Callable]] = None, + cross_options: Optional[str] = None, lib_format: str = "so", ): """Save this TVMCModel to file. @@ -192,6 +193,8 @@ def export_package( If not provided, the package will be saved to a generically named file in tmp. cross : str or callable object, optional Function that performs the actual compilation. + cross_options : str, optional + Command line options to be passed to the cross compiler. lib_format : str How to export the modules function library. Must be one of "so" or "tar". @@ -214,9 +217,14 @@ def export_package( if not cross: executor_factory.get_lib().export_library(path_lib) else: - executor_factory.get_lib().export_library( - path_lib, tvm.contrib.cc.cross_compiler(cross) - ) + if not cross_options: + executor_factory.get_lib().export_library( + path_lib, tvm.contrib.cc.cross_compiler(cross) + ) + else: + executor_factory.get_lib().export_library( + path_lib, tvm.contrib.cc.cross_compiler(cross, options=cross_options.split(" ")) + ) self.lib_path = path_lib with open(temp.relpath(graph_name), "w") as graph_file: diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index a023689cc86d..16c02335c8a0 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -29,6 +29,8 @@ from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCPackage +from tvm.contrib import utils + def test_save_dumps(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("data") @@ -92,6 +94,33 @@ def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): assert os.path.exists(dumps_path) +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) +def test_cross_compile_options_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + + fake_sysroot_dir = utils.tempdir().relpath("") + + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + tvmc_package = tvmc.compile( + tvmc_model, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'", + dump_code="asm", + cross="aarch64-linux-gnu-gcc", + cross_options="--sysroot=" + fake_sysroot_dir, + ) + dumps_path = tvmc_package.package_path + ".asm" + + # check for output types + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) + + def test_compile_keras__save_module(keras_resnet50, tmpdir_factory): # some CI environments wont offer tensorflow/Keras, so skip in case it is not present pytest.importorskip("tensorflow") @@ -137,6 +166,34 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50): assert os.path.exists(dumps_path) +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) +def test_cross_compile_options_aarch64_keras_module(keras_resnet50): + # some CI environments wont offer tensorflow/Keras, so skip in case it is not present + pytest.importorskip("tensorflow") + + fake_sysroot_dir = utils.tempdir().relpath("") + + tvmc_model = tvmc.load(keras_resnet50) + tvmc_package = tvmc.compile( + tvmc_model, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'", + dump_code="asm", + cross="aarch64-linux-gnu-gcc", + cross_options="--sysroot=" + fake_sysroot_dir, + ) + dumps_path = tvmc_package.package_path + ".asm" + + # check for output types + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) + + def verify_compile_onnx_module(model, shape_dict=None): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") @@ -186,6 +243,34 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50): assert os.path.exists(dumps_path) +# This test will be skipped if the AArch64 cross-compilation toolchain is not installed. +@pytest.mark.skipif( + not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" +) +def test_cross_compile_options_aarch64_onnx_module(onnx_resnet50): + # some CI environments wont offer onnx, so skip in case it is not present + pytest.importorskip("onnx") + + fake_sysroot_dir = utils.tempdir().relpath("") + + tvmc_model = tvmc.load(onnx_resnet50) + tvmc_package = tvmc.compile( + tvmc_model, + target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + dump_code="asm", + cross="aarch64-linux-gnu-gcc", + cross_options="--sysroot=" + fake_sysroot_dir, + ) + dumps_path = tvmc_package.package_path + ".asm" + + # check for output types + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) + + @tvm.testing.requires_opencl def test_compile_opencl(tflite_mobilenet_v1_0_25_128): pytest.importorskip("tflite")