diff --git a/docker/install/ubuntu_install_vitis_ai_core.sh b/docker/install/ubuntu_install_vitis_ai_core.sh
old mode 100644
new mode 100755
diff --git a/docs/deploy/vitis_ai.rst b/docs/deploy/vitis_ai.rst
index 1ce89ebed9c2..d3e3ca004f7e 100755
--- a/docs/deploy/vitis_ai.rst
+++ b/docs/deploy/vitis_ai.rst
@@ -196,7 +196,7 @@ Hardware setup and docker build
     pip3 install -e . --user
 
 Edge (DPUCZDX8G)
-^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~
 
 For edge deployment we make use of two systems referred to as host and
@@ -435,8 +435,8 @@ Cloud usage
 This section shows how to accelerate a convolutional neural network
 model in TVM with Vitis-AI on the cloud.
 
-To be able to target the Vitis-AI cloud DPUCADX8G target we first have
-to import the target in PyXIR. This PyXIR package is the interface being
+To be able to target the Vitis-AI cloud DPUCADX8G we first have
+to import the DPU target in PyXIR. This PyXIR package is the interface being
 used by TVM to integrate with the Vitis-AI stack. Additionally, import
 the typical TVM and Relay modules and the Vitis-AI contrib module inside
 TVM.
@@ -451,32 +451,29 @@ TVM.
     from tvm.contrib.target import vitis_ai
     from tvm.contrib import utils, graph_executor
     from tvm.relay.build_module import bind_params_by_name
-    from tvm.relay.op.contrib.vitis_ai import annotation
+    from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai
 
 After importing a convolutional neural network model using the usual
 Relay API's, annotate the Relay expression for the given Vitis-AI DPU
 target and partition the graph.
 
 .. code:: python
-
-    mod["main"] = bind_params_by_name(mod["main"], params)
-    mod = annotation(mod, params, target)
-    mod = relay.transform.MergeCompilerRegions()(mod)
-    mod = relay.transform.PartitionGraph()(mod)
+
+    dpu = 'DPUCADX8G'
+    mod = partition_for_vitis_ai(mod, params, dpu)
 
 Now, we can build the TVM runtime library for executing the model. The
 TVM target is 'llvm' as the operations that can't be handled by the DPU
-are executed on the CPU. The Vitis-AI target is DPUCADX8G as we are
-targeting the cloud DPU and this target is passed as a config to the TVM
+are executed on the CPU. The Vitis-AI DPU is DPUCADX8G as we are
+targeting the cloud DPU and this DPU identifier is passed as a config to the TVM
 build call.
 
 .. code:: python
 
-    tvm_target = 'llvm'
-    target='DPUCADX8G'
+    target = 'llvm'
 
-    with tvm.transform.PassContext(opt_level=3, config= {'relay.ext.vitis_ai.options.target': target}):
-        lib = relay.build(mod, tvm_target, params=params)
+    with tvm.transform.PassContext(opt_level=3, config={'relay.ext.vitis_ai.options': {'dpu': dpu}}):
+        lib = relay.build(mod, target, params=params)
 
 As one more step before we can accelerate a model with Vitis-AI in TVM
 we have to quantize and compile the model for execution on the DPU. We
@@ -537,8 +534,8 @@ A complete ResNet 18 example can be found `here
-`__. Additionally, we
+DPU's, see `edge DPU's info <#edge-requirements>`__. Additionally, we
 provide the 'export_runtime_module' config that points to a file to which
 we can export the Vitis-AI runtime module. We have to do this because we
 will first be compiling and quantizing the model on the host machine before building
@@ -617,13 +613,15 @@ can be included.
 
 .. code:: python
 
-    tvm_target = 'llvm'
-    target='DPUCZDX8G-zcu104'
+    target = 'llvm'
+    dpu = 'DPUCZDX8G-zcu104'
     export_rt_mod_file = "vitis_ai.rtmod"
-
-    with tvm.transform.PassContext(opt_level=3, config= {'relay.ext.vitis_ai.options.target': target,
-                                                         'relay.ext.vitis_ai.options.export_runtime_module': export_rt_mod_file}):
-        lib = relay.build(mod, tvm_target, params=params)
+
+    build_options = {
+        'dpu': dpu,
+        'export_runtime_module': export_rt_mod_file
+    }
+    with tvm.transform.PassContext(opt_level=3, config={'relay.ext.vitis_ai.options': build_options}):
+        lib = relay.build(mod, target, params=params)
 
 We will quantize and compile the model for execution on the DPU using on-the-fly
 quantization on the host machine. This makes use of TVM inference calls
@@ -658,15 +656,17 @@ in the TVM build.
 
 .. code:: python
 
     # Export lib for aarch64 target
-    tvm_target = tvm.target.arm_cpu('ultra96')
+    target = tvm.target.arm_cpu('ultra96')
     lib_kwargs = {
         'fcompile': contrib.cc.create_shared,
         'cc': "/usr/aarch64-linux-gnu/bin/ld"
     }
-
-    with tvm.transform.PassContext(opt_level=3,
-                                   config={'relay.ext.vitis_ai.options.load_runtime_module': export_rt_mod_file}):
-        lib_arm = relay.build(mod, tvm_target, params=params)
+
+    build_options = {
+        'load_runtime_module': export_rt_mod_file
+    }
+    with tvm.transform.PassContext(opt_level=3, config={'relay.ext.vitis_ai.options': build_options}):
+        lib_arm = relay.build(mod, target, params=params)
 
-    lib_dpuv2.export_library('tvm_dpu_arm.so', **lib_kwargs)
+    lib_arm.export_library('tvm_dpu_arm.so', **lib_kwargs)
 
@@ -688,7 +688,7 @@ as root (execute ``su`` in terminal to log into root).
 
    You will see a warning about the 'cpu-tf' runtime not being found. This warning is
    expected on the board and can be ignored. Note also that you **shouldn't** import the
-   PyXIR targets in the run script (``import pyxir.contrib.target.DPUCZDX8G``).
+   PyXIR DPU targets in the run script (``import pyxir.contrib.target.DPUCZDX8G``).
 
 .. code:: python
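For the on-board half of the edge flow, the exported ``tvm_dpu_arm.so`` is loaded and executed with the regular TVM graph executor. The following is a minimal run-script sketch, not the documented script itself: the input name ``data`` and the input shape are illustrative placeholders that depend on the actual model, and, as noted above, no PyXIR DPU targets should be imported here.

.. code:: python

    import numpy as np
    import tvm
    from tvm.contrib import graph_executor

    # Do NOT `import pyxir.contrib.target.DPUCZDX8G` in this script (see note above)
    dev = tvm.cpu()
    # Load the library that was cross-compiled and exported on the host
    lib = tvm.runtime.load_module("tvm_dpu_arm.so")
    module = graph_executor.GraphModule(lib["default"](dev))

    # Input name and shape are illustrative placeholders
    module.set_input("data", np.random.rand(1, 3, 224, 224).astype("float32"))
    module.run()
    out = module.get_output(0).asnumpy()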
" + "Help: (https://tvm.apache.org/docs/deploy/vitis_ai.html) " + ) from None - def __init__(self, model_name, function): - self.model_name = model_name self.function = function + self.dpu_target = dpu_target self.params = {} - def convert_pyxir(self, target): - """Convert Relay expression to PyXIR XGraph""" + def build(self): + """ "Convert the Relay expression to a PyXIR XGraph to instantiate + the Vitis AI runtime + + Returns + ------- + xgraph_str : str + Serialized XGraph + """ xgraph = pyxir.frontend.tvm.from_relay( self.function, params=self.params, postprocessing=None ) - xgraph = pyxir.partition(xgraph, targets=[target]) - return xgraph + xgraph = pyxir.partition(xgraph, targets=[self.dpu_target]) + output_relay_ids = self.get_output_names() + layers = xgraph.get_layers() + + # Get the output tensor names using XGraph and output Relay ids + out_tensor_names = ["unknown_name"] * len(output_relay_ids) + for layer in layers: + if not layer.internal: + for relay_id in layer.attrs["relay_id"]: + if relay_id in output_relay_ids: + out_tensor_names[output_relay_ids.index(relay_id)] = layer.name + break + if any([name == "unkown_name" for name in out_tensor_names]): + raise ValueError( + "During codegeneration the loading of subexpression" + " failed due to output tensor name mismatch in Relay PyXIR interface." + ) + xgraph.meta_attrs["tvm_out_tensors"] = out_tensor_names + xgraph_str = pyxir.get_xgraph_str(xgraph) + return xgraph_str def get_output_names(self): """Get output names from Relay expression""" @@ -66,49 +122,73 @@ def vitis_ai_compiler(ref): """Create a Vitis-AI runtime from the provided Relay expression""" assert isinstance(ref, tvm.relay.function.Function) - out_tensor_names = [] name = str(ref.attrs.global_symbol) pass_context = tvm.get_global_func("transform.GetCurrentPassContext")() - # The target Vitis-AI accelerator device - target = ( - str(pass_context.config["relay.ext.vitis_ai.options.target"]) - if "relay.ext.vitis_ai.options.target" in pass_context.config + cfg = ( + pass_context.config["relay.ext.vitis_ai.options"] + if "relay.ext.vitis_ai.options" in pass_context.config else None ) - # (Optional configs) The build and work directories to be used by Vitis-AI - vai_build_dir = ( - str(pass_context.config["relay.ext.vitis_ai.options.build_dir"]) - if "relay.ext.vitis_ai.options.build_dir" in pass_context.config - else tvm.contrib.utils.tempdir().relpath("") - ) - vai_work_dir = ( - str(pass_context.config["relay.ext.vitis_ai.options.work_dir"]) - if "relay.ext.vitis_ai.options.work_dir" in pass_context.config - else tvm.contrib.utils.tempdir().relpath("") - ) + # Backward compatibility with old pass context configs + if cfg is None: + warnings.warn( + "You are using a deprecated way of passing build configs (e.g." + " `relay.ext.vitis_ai.options.target`). Check out the Vitis AI " + " documentation here: https://tvm.apache.org/docs/deploy/vitis_ai.html" + " to switch to recommended way for passing build configs." + ) - # (Optional configs) Export and load PyXIR runtime module to file if provided. 
This is used to - # compile and quantize a model on the host and deploy it at the edge - export_runtime_module = ( - str(pass_context.config["relay.ext.vitis_ai.options.export_runtime_module"]) - if "relay.ext.vitis_ai.options.export_runtime_module" in pass_context.config - else "" - ) - load_runtime_module = ( - str(pass_context.config["relay.ext.vitis_ai.options.load_runtime_module"]) - if "relay.ext.vitis_ai.options.load_runtime_module" in pass_context.config - else "" - ) + # The target Vitis-AI accelerator device + dpu_target = ( + str(pass_context.config["relay.ext.vitis_ai.options.target"]) + if "relay.ext.vitis_ai.options.target" in pass_context.config + else None + ) + + # (Optional configs) The build and work directories to be used by Vitis-AI + vai_build_dir = ( + str(pass_context.config["relay.ext.vitis_ai.options.build_dir"]) + if "relay.ext.vitis_ai.options.build_dir" in pass_context.config + else tvm.contrib.utils.tempdir().relpath("") + ) + vai_work_dir = ( + str(pass_context.config["relay.ext.vitis_ai.options.work_dir"]) + if "relay.ext.vitis_ai.options.work_dir" in pass_context.config + else tvm.contrib.utils.tempdir().relpath("") + ) + + # (Optional configs) Export and load PyXIR runtime module to file if provided. This is + # used to compile and quantize a model on the host and deploy it at the edge + export_runtime_module = ( + str(pass_context.config["relay.ext.vitis_ai.options.export_runtime_module"]) + if "relay.ext.vitis_ai.options.export_runtime_module" in pass_context.config + else "" + ) + load_runtime_module = ( + str(pass_context.config["relay.ext.vitis_ai.options.load_runtime_module"]) + if "relay.ext.vitis_ai.options.load_runtime_module" in pass_context.config + else "" + ) + else: + dpu_target = cfg.dpu if cfg.dpu else None + # (Optional configs) The build and work directories to be used by Vitis AI + vai_build_dir = cfg.build_dir if cfg.build_dir else tvm.contrib.utils.tempdir().relpath("") + + # (Optional configs) Export and load PyXIR runtime module to file if provided. This is + # used to compile and quantize a model on the host and deploy it at the edge + vai_work_dir = cfg.work_dir if cfg.work_dir else tvm.contrib.utils.tempdir().relpath("") + export_runtime_module = cfg.export_runtime_module + load_runtime_module = cfg.load_runtime_module # Config checks - if load_runtime_module and target is not None: + if load_runtime_module and dpu_target is not None: warnings.warn( - "Both `load_runtime_module` and `target` configs were specified." + "Both `load_runtime_module` and `dpu` configs were specified." 
" The `load_runtime_module` points to a prebuilt runtime module with" - " an internal target so the `target` config will be ignored" + " an internal DPU target so the `dpu` config will be ignored" ) if load_runtime_module and "relay.ext.vitis_ai.options.build_dir" in pass_context.config: warnings.warn( @@ -126,30 +206,14 @@ def vitis_ai_compiler(ref): # If load_runtime_module is not set, we will build the PyXIR runtime module from scratch if load_runtime_module == "": # Convert Relay expression into XGraph and do partitioning inside PyXIR - builder = CodegenVitisAI(name, ref) - xgraph = builder.convert_pyxir(target) - output_relay_ids = builder.get_output_names() - layers = xgraph.get_layers() - - # Get the output tensor names using XGraph and output Relay ids - out_tensor_names = ["unknown_name"] * len(output_relay_ids) - for layer in layers: - if not layer.internal: - for relay_id in layer.attrs["relay_id"]: - if relay_id in output_relay_ids: - out_tensor_names[output_relay_ids.index(relay_id)] = layer.name - break - if any([name == "unkown_name" for name in out_tensor_names]): - raise ValueError( - "During codegeneration the loading of subexpression \ - failed due to output tensor name mismatch in Relay PyXIR interface." - ) - xgraph.meta_attrs["tvm_out_tensors"] = out_tensor_names - xgraph_str = pyxir.get_xgraph_str(xgraph) + codegen = CodegenVitisAI(ref, dpu_target) + xgraph_str = codegen.build() runtime_func = "tvm.vitis_ai_runtime.from_xgraph" fcreate = tvm._ffi.get_global_func(runtime_func) - return fcreate(name, xgraph_str, target, vai_build_dir, vai_work_dir, export_runtime_module) + return fcreate( + name, xgraph_str, dpu_target, vai_build_dir, vai_work_dir, export_runtime_module + ) runtime_func = "tvm.vitis_ai_runtime.from_rt_mod" fcreate = tvm._ffi.get_global_func(runtime_func) diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 99ed11789364..bdb4c6200b98 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -250,7 +250,7 @@ def drive_tune(args): for codegen_from_cli in extra_targets: codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"]) partition_function = codegen["pass_pipeline"] - mod = partition_function(mod, params) + mod = partition_function(mod, params, **codegen_from_cli["opts"]) # min_repeat_ms should be: # a. 
diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index 99ed11789364..bdb4c6200b98 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -250,7 +250,7 @@ def drive_tune(args):
     for codegen_from_cli in extra_targets:
         codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
         partition_function = codegen["pass_pipeline"]
-        mod = partition_function(mod, params)
+        mod = partition_function(mod, params, **codegen_from_cli["opts"])
 
     # min_repeat_ms should be:
     # a. the value provided by the user, if any, or
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index ba7722f6b38e..b8450750f115 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -193,7 +193,7 @@ def compile_model(
     for codegen_from_cli in extra_targets:
         codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
         partition_function = codegen["pass_pipeline"]
-        mod = partition_function(mod, params)
+        mod = partition_function(mod, params, **codegen_from_cli["opts"])
 
         if codegen["config_key"] is not None:
             config[codegen["config_key"]] = codegen_from_cli["opts"]
diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py
index 886160ad000c..ac1a41a0c4a9 100644
--- a/python/tvm/driver/tvmc/composite_target.py
+++ b/python/tvm/driver/tvmc/composite_target.py
@@ -19,9 +19,14 @@
 """
 import logging
 
+# Make sure Vitis AI codegen is registered
+import tvm.contrib.target.vitis_ai  # pylint: disable=unused-import
+
 from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
 from tvm.relay.op.contrib.ethosn import partition_for_ethosn
 from tvm.relay.op.contrib.bnns import partition_for_bnns
+from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai
+
 from .common import TVMCException
 
@@ -29,9 +34,16 @@
 # pylint: disable=invalid-name
 logger = logging.getLogger("TVMC")
 
-# Global dictionary to map targets with the configuration key
-# to be used in the PassContext (if any), and a function
-# responsible for partitioning to that target.
+
+# Global dictionary to map targets
+#
+# Options
+# -------
+# config_key : str
+#     The configuration key to be used in the PassContext (if any).
+# pass_pipeline : Callable
+#     A function to transform a Module before compilation; currently used
+#     mainly to partition the module for the target.
 REGISTERED_CODEGEN = {
     "compute-library": {
         "config_key": None,
         "pass_pipeline": partition_for_arm_compute_lib,
@@ -45,6 +57,10 @@
         "config_key": None,
         "pass_pipeline": partition_for_bnns,
     },
+    "vitis-ai": {
+        "config_key": "relay.ext.vitis_ai.options",
+        "pass_pipeline": partition_for_vitis_ai,
+    },
 }
 
@@ -62,10 +78,15 @@ def get_codegen_names():
 def get_codegen_by_target(name):
     """Return a codegen entry by name.
 
+    Parameters
+    ----------
+    name : str
+        The name of the target for which the codegen info should be retrieved.
+
     Returns
     -------
     dict
-        requested target information
+        requested target codegen information
     """
     try:
         return REGISTERED_CODEGEN[name]
diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 17fdbf941e08..6234c944d4e4 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -43,7 +43,7 @@ def is_arm_compute_runtime_enabled():
     return False
 
 
-def partition_for_arm_compute_lib(mod, params=None):
+def partition_for_arm_compute_lib(mod, params=None, **opts):
     """Partition the graph greedily offloading supported
     operators to Arm Compute Library.
 
diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py
index 478a1ec46f26..2c63d63a36ef 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -46,7 +46,7 @@ def ethosn_available():
     return Available.SW_AND_HW if hw else Available.SW_ONLY
 
 
-def partition_for_ethosn(mod, params=None):
+def partition_for_ethosn(mod, params=None, **opts):
     """Partition the graph greedily offloading supported
     operators to Arm Ethos-N NPU.
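The ``**opts`` catch-all added to the partition functions is what lets the tvmc driver forward per-codegen command line options, as in the loops in ``compiler.py`` and ``autotuner.py`` above. A sketch of that lookup-and-partition pattern; the model path and the ``codegen_from_cli`` entry are illustrative stand-ins for what tvmc would parse from the command line:

.. code:: python

    from tvm.driver import tvmc
    from tvm.driver.tvmc import composite_target

    # Model path is a placeholder; tvmc.load returns the Relay module and params
    mod, params = tvmc.load("mobilenet_v1_1.0_224_quant.tflite")

    # Illustrative entry, as parsed from the CLI string
    # "vitis-ai -dpu=DPUCZDX8G-zcu104, llvm"
    codegen_from_cli = {"name": "vitis-ai", "opts": {"dpu": "DPUCZDX8G-zcu104"}}

    codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
    partition_function = codegen["pass_pipeline"]
    # Partition functions that don't use a given option simply absorb it via **opts
    mod = partition_function(mod, params, **codegen_from_cli["opts"])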
diff --git a/python/tvm/relay/op/contrib/vitis_ai.py b/python/tvm/relay/op/contrib/vitis_ai.py
index aaa9f99e61ed..0c05c8db7435 100644
--- a/python/tvm/relay/op/contrib/vitis_ai.py
+++ b/python/tvm/relay/op/contrib/vitis_ai.py
@@ -17,25 +17,52 @@
 # pylint: disable=invalid-name, unused-argument, no-else-return, E1102
 """Vitis-AI codegen annotation of supported operators"""
 
+import warnings
 import numpy as np
 
-import pyxir
-import pyxir.frontend.tvm
-
 from tvm import relay
 import tvm._ffi
-from tvm.relay.expr import Tuple, TupleGetItem
 from tvm.relay import transform
+from tvm.relay.expr import Tuple, TupleGetItem
+from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.op.annotation import compiler_begin, compiler_end
 
+# Placeholder for PyXIR module
+pyxir = None
+
 
 @transform.function_pass(opt_level=0)
 class VitisAIAnnotationPass:
-    """Responsible for annotating Relay expressions for Vitis-AI DPU accelerators"""
+    """Responsible for annotating Relay expressions for Vitis-AI DPU accelerators
+
+    Parameters
+    ----------
+    compiler : str
+        The compiler name used for annotations (`vitis_ai`).
+    dpu_target : str
+        The Vitis AI DPU target identifier.
+    params : dict
+        A dictionary containing the module's parameters.
+    """
+
+    def __init__(self, compiler, dpu_target, params):
+        global pyxir
+        try:
+            if pyxir is None:
+                pyxir = __import__("pyxir")
+                __import__("pyxir.frontend.tvm")
+        except ImportError:
+            # add "from None" to silence
+            # "During handling of the above exception, another exception occurred"
+            raise ImportError(
+                "The pyxir package is required for the Vitis AI backend. "
+                "Please install it first. "
+                "Help: (https://tvm.apache.org/docs/deploy/vitis_ai.html) "
+            ) from None
 
-    def __init__(self, compiler, relay_ids):
         self.compiler = compiler
-        self.relay_ids = relay_ids
+        self.dpu_target = dpu_target
+        self.params = params
 
     def transform_function(self, func, mod, ctx):
         """Transform function for annotating Relay module"""
@@ -80,25 +107,68 @@ def visit_call(self, call):
             else:
                 return super().visit_call(call)
 
+        xgraph = pyxir.frontend.tvm.from_relay(mod, self.params, postprocessing=None)
+        xgraph = pyxir.partition(xgraph, targets=[self.dpu_target])
+
+        layers = xgraph.get_layers()
+        relay_ids = [
+            list(np.array(layer.attrs["relay_id"]).flatten())
+            for layer in layers
+            if layer.target == self.dpu_target
+        ]
+        self.relay_ids = [item for sublist in relay_ids for item in sublist]
+
         return Annotator().visit(func)
 
 
 def annotation(mod, params, target):
-    """Annotate Relay expression for Vitis-AI DPU accelerators"""
+    """DEPRECATED
+
+    Annotate Relay expression for offloading operators to Vitis AI DPU accelerators.
+    NOTE: This function does the same as `partition_for_vitis_ai` below but is
+    kept for backward compatibility"""
     # We need type information for supporting models that contain operations that don't
     # have a Relay to XLayer translation
+    warnings.warn(
+        "tvm.relay.op.contrib.vitis_ai.annotation() is being deprecated."
+        " Please use tvm.relay.op.contrib.vitis_ai.partition_for_vitis_ai() instead."
+        " Check out https://tvm.apache.org/docs/deploy/vitis_ai.html for documentation."
+    )
     mod = relay.transform.InferType()(mod)
+    mod = VitisAIAnnotationPass("vitis_ai", target, params)(mod)
+    return mod
 
-    xgraph = pyxir.frontend.tvm.from_relay(mod, params, postprocessing=None)
-    xgraph = pyxir.partition(xgraph, targets=[target])
-    layers = xgraph.get_layers()
-    relay_ids = [
-        list(np.array(layer.attrs["relay_id"]).flatten())
-        for layer in layers
-        if layer.target == target
-    ]
-    relay_ids_flatten = [item for sublist in relay_ids for item in sublist]
-    mod = VitisAIAnnotationPass("vitis_ai", relay_ids_flatten)(mod)
 
-    return mod
+def partition_for_vitis_ai(mod, params=None, dpu=None, **opts):
+    """Partition the Relay expression for offloading operators to Vitis AI DPU
+
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    dpu : str
+        The DPU identifier (e.g. DPUCZDX8G-zcu104, DPUCADX8G)
+
+    Returns
+    -------
+    ret : Module
+        The partitioned module.
+    """
+
+    if dpu is None:
+        raise ValueError("Please pass a Vitis AI DPU identifier to the partitioning function")
+
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)
+
+    seq = tvm.transform.Sequential(
+        [
+            transform.InferType(),
+            VitisAIAnnotationPass("vitis_ai", dpu, params),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
+
+    return seq(mod)
diff --git a/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc b/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc
index f74b5306c5f4..5426a2dc1e65 100644
--- a/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc
+++ b/src/relay/backend/contrib/vitis_ai/config_vitis_ai.cc
@@ -29,6 +29,40 @@ namespace relay {
 namespace contrib {
 namespace vitis_ai {
 
+/*! \brief Attributes to store the compiler options for Vitis AI */
+struct VitisAICompilerConfigNode : public tvm::AttrsNode<VitisAICompilerConfigNode> {
+  String dpu;
+  String build_dir;
+  String work_dir;
+  String export_runtime_module;
+  String load_runtime_module;
+
+  TVM_DECLARE_ATTRS(VitisAICompilerConfigNode, "ext.attrs.VitisAICompilerConfigNode") {
+    TVM_ATTR_FIELD(dpu).describe("Vitis AI DPU identifier").set_default("");
+    TVM_ATTR_FIELD(build_dir)
+        .describe("Build directory to be used (optional, debug)")
+        .set_default("");
+    TVM_ATTR_FIELD(work_dir)
+        .describe("Work directory to be used (optional, debug)")
+        .set_default("");
+    TVM_ATTR_FIELD(export_runtime_module)
+        .describe("Export the Vitis AI runtime module to this file")
+        .set_default("");
+    TVM_ATTR_FIELD(load_runtime_module)
+        .describe("Load the Vitis AI runtime module from this file")
+        .set_default("");
+  }
+};
+
+class VitisAICompilerConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VitisAICompilerConfig, Attrs,
+                                            VitisAICompilerConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(VitisAICompilerConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.vitis_ai.options", VitisAICompilerConfig);
+
+// Following config options are here for backward compatibility (deprecated APIs)
 /*! \brief The target Vitis-AI accelerator device */
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.vitis_ai.options.target", String);
 /*! \brief (Optional config) The build directory to be used by Vitis-AI */
diff --git a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc
index 0e5e2ce4c4fa..f9c1cd82b483 100755
--- a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc
+++ b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.cc
@@ -50,7 +50,7 @@ VitisAIRuntime::VitisAIRuntime(const std::string& symbol_name,
-                               const Array<String> const_names, const std::string& target,
+                               const Array<String> const_names, const std::string& dpu_target,
                                const std::string& build_dir, const std::string& work_dir,
                                const std::string& export_rt_mod_path)
     : symbol_name_(symbol_name),
@@ -62,22 +62,23 @@ VitisAIRuntime::VitisAIRuntime(const std::string& symbol_name, const std::string
   in_tensor_names_ = xgraph->get_input_names();
   out_tensor_names_ = xgraph->get_meta_attr("tvm_out_tensors").get_strings();
 
-  pyxir::partition(xgraph, std::vector<std::string>{target}, "");
+  pyxir::partition(xgraph, std::vector<std::string>{dpu_target}, "");
 
   pyxir::RunOptionsHolder run_options(new pyxir::runtime::RunOptions());
   run_options->on_the_fly_quantization = true;
   run_options->build_dir = build_dir;
+  run_options->export_runtime_module_path = export_rt_mod_path_;
   if (!work_dir.empty()) run_options->work_dir = work_dir;
 
   rt_mod_ =
-      pyxir::build_rt(xgraph, target, in_tensor_names_, out_tensor_names_, "vai", run_options);
+      pyxir::build_rt(xgraph, dpu_target, in_tensor_names_, out_tensor_names_, "vai", run_options);
 }
 
 Module VitisAIRuntimeCreate(const std::string& name, const std::string& xgraph_str,
-                            const std::string& target, const std::string& build_dir,
+                            const std::string& dpu_target, const std::string& build_dir,
                             const std::string& work_dir, const std::string& export_rt_mod_path) {
   Array<String> const_vars;
-  auto exec = make_object<VitisAIRuntime>(name, xgraph_str, const_vars, target, build_dir, work_dir,
-                                          export_rt_mod_path);
+  auto exec = make_object<VitisAIRuntime>(name, xgraph_str, const_vars, dpu_target, build_dir,
+                                          work_dir, export_rt_mod_path);
   return Module(exec);
 }
diff --git a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h
index 1092bc0ba27b..cad3b5e5a7ff 100755
--- a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h
+++ b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h
@@ -62,14 +62,14 @@ class VitisAIRuntime : public ModuleNode {
    * \param symbol_name The name of the function.
    * \param xgraph_str serialized XGraph representation
    * \param const_names The names of each constant in the sub-graph.
-   * \param target The Vitis-AI device target (e.g. DPUCADX8G, DPUCZDX8G).
+   * \param dpu_target The Vitis-AI DPU target identifier (e.g. DPUCADX8G, DPUCZDX8G-zcu104).
    * \param build_dir The directory to be used for Vitis-AI build files.
    * \param work_dir The directory to be used for Vitis-AI work files.
    * \param export_rt_mod_path The path to the file to be used for exporting the
    * PyXIR runtime module.
   */
   VitisAIRuntime(const std::string& symbol_name, const std::string& xgraph_str,
-                 const Array<String> const_names, const std::string& target,
+                 const Array<String> const_names, const std::string& dpu_target,
                  const std::string& build_dir, const std::string& work_dir,
                  const std::string& export_runtime_module_path);
diff --git a/tests/python/contrib/test_vitis_ai/infrastructure.py b/tests/python/contrib/test_vitis_ai/infrastructure.py
index 501ee255c143..bd3d85747105 100644
--- a/tests/python/contrib/test_vitis_ai/infrastructure.py
+++ b/tests/python/contrib/test_vitis_ai/infrastructure.py
@@ -31,7 +31,7 @@
 from tvm import relay
 from tvm import runtime
 from tvm.relay import transform
-from tvm.relay.op.contrib.vitis_ai import annotation
+from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai
 from tvm.relay.build_module import bind_params_by_name
 from tvm.contrib.target import vitis_ai
 from tvm.contrib import graph_executor
@@ -84,10 +84,7 @@ def build_module(
         opt_level=3, config={"relay.ext.vitis_ai.options.target": dpu_target}
     ):
         if enable_vitis_ai:
-            mod["main"] = bind_params_by_name(mod["main"], params)
-            mod = annotation(mod, params, dpu_target)
-            mod = transform.MergeCompilerRegions()(mod)
-            mod = transform.PartitionGraph()(mod)
+            mod = partition_for_vitis_ai(mod, params, dpu_target)
             tvm_op_count = get_cpu_op_count(mod)
             assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format(
                 tvm_op_count, tvm_ops
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 6d17b4e37114..8cd77b8cde4a 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -25,6 +25,7 @@
 import tvm
 
 from tvm.relay.op.contrib.ethosn import ethosn_available
+from tvm.contrib.target.vitis_ai import vitis_ai_available
 from tvm.driver import tvmc
 
@@ -211,6 +212,28 @@ def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant):
     assert type(dumps) is dict
 
 
+@pytest.mark.skipif(
+    not vitis_ai_available(),
+    reason="--target=vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'",
+)
+def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v1_1_quant):
+    pytest.importorskip("tflite")
+
+    mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
+    graph, lib, params, dumps = tvmc.compiler.compile_model(
+        mod,
+        params,
+        target="vitis-ai -dpu=DPUCZDX8G-zcu104 -export_runtime_module=vitis_ai.rtmod, llvm",
+        dump_code="relay",
+    )
+
+    # check for output types
+    assert type(graph) is str
+    assert type(lib) is tvm.runtime.module.Module
+    assert type(params) is dict
+    assert type(dumps) is dict
+
+
 @mock.patch("tvm.relay.build")
 @mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target")
 @mock.patch("tvm.driver.tvmc.load")
@@ -218,7 +241,7 @@ def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_relay):
     mock_codegen = {}
     mock_codegen["config_key"] = "relay.ext.mock.options"
-    mock_codegen["pass_pipeline"] = lambda *args: None
+    mock_codegen["pass_pipeline"] = lambda *args, **kwargs: None
 
     mock_fe.return_value = (None, None)
     mock_ct.return_value = mock_codegen
diff --git a/tests/python/driver/tvmc/test_composite_target.py b/tests/python/driver/tvmc/test_composite_target.py
index cef8b117d989..0a0b45eeb970 100644
--- a/tests/python/driver/tvmc/test_composite_target.py
+++ b/tests/python/driver/tvmc/test_composite_target.py
@@ -34,6 +34,7 @@ def test_get_codegen_names():
     names = tvmc.composite_target.get_codegen_names()
 
     assert "ethos-n77" in names
+    assert "vitis-ai" in names
 
     assert len(names) > 0
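Put together, these changes make the Vitis AI flow reachable through the tvmc Python driver as well. A sketch mirroring the new test above; the model file name is a placeholder:

.. code:: python

    from tvm.driver import tvmc

    mod, params = tvmc.load("mobilenet_v1_1.0_224_quant.tflite")
    graph, lib, params, dumps = tvmc.compiler.compile_model(
        mod,
        params,
        target="vitis-ai -dpu=DPUCZDX8G-zcu104 -export_runtime_module=vitis_ai.rtmod, llvm",
        dump_code="relay",
    )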