diff --git a/python/tvm/relay/backend/aot/__init__.py b/python/tvm/relay/backend/aot/__init__.py
new file mode 100644
index 000000000000..888b331ae092
--- /dev/null
+++ b/python/tvm/relay/backend/aot/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module defines the Relay ahead-of-time (AoT) compiler,
+which translates Relay ASTs into C++ code that calls into
+already-compiled operators. These end-to-end compiled
+programs can in principle run without a runtime.
+"""
+from .aot import compile_prog
diff --git a/python/tvm/relay/backend/aot/aot.py b/python/tvm/relay/backend/aot/aot.py
new file mode 100644
index 000000000000..f66e1c49f797
--- /dev/null
+++ b/python/tvm/relay/backend/aot/aot.py
@@ -0,0 +1,283 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Defines the entry point into the AoT compiler.
+"""
+import ctypes
+import os
+import subprocess
+import tempfile
+import time
+
+import tvm
+from tvm import relay, get_global_func, register_func
+from tvm.relay.function import Function
+from tvm.relay.expr import Expr, Let, GlobalVar
+from tvm.relay.adt import Constructor
+from tvm.relay.expr_functor import ExprFunctor
+from tvm.relay.backend import compile_engine
+from .little_cpp import (PackedCall, CPPFunction, Invoke, Decl, CPPIf,
+                         CPPTuple, CPPMatch, CPPConstructor, CPPTupleGetItem,
+                         CPPRefCreate, CPPRefRead, CPPRefWrite)
+from . import to_source
+from .convert import convert
+
+TVM_PATH = os.environ['TVM_HOME']
+
+def must_run_process(args):
+    proc = subprocess.run(args, check=True)
+    assert proc.returncode == 0
+
+def compile_cpp(source, lib_name, flags=None, lib_path=None):
+    """
+    Compiles the given source into a C++ library
+    and returns the full path to the compiled library.
+    """
+    if flags is None:
+        flags = []
+
+    if lib_path is None:
+        lib_path = os.curdir
+
+    debug_source_path = os.path.join(lib_path, 'source.cc')
+    # Write out the file for debugging.
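+    # (clang-format is run on this debug copy below so the emitted C++ stays readable.)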
+    with open(debug_source_path, 'w') as source_file:
+        source_file.write(source)
+
+    # with tempfile.TemporaryDirectory() as tmpdir:
+    tmpdir = tempfile.mkdtemp(prefix="relay_aot_compiler")
+    lib_path = os.path.join(tmpdir, lib_name)
+    source_path = os.path.join(tmpdir, 'source.cc')
+    with open(source_path, 'w') as source_file:
+        source_file.write(source)
+
+    must_run_process(["clang-format", "-i", debug_source_path])
+
+    system = os.uname()[0]
+    include_paths = [
+        f"-I{TVM_PATH}/3rdparty/dmlc-core/include",
+        f"-I{TVM_PATH}/3rdparty/dlpack/include",
+        f"-I{TVM_PATH}/3rdparty/HalideIR/src",
+        f"-I{TVM_PATH}/include",
+        f"-L{TVM_PATH}/build"
+    ]
+
+    if system == 'Darwin':
+        command = [
+            "clang", "-std=c++14", "-shared", "-undefined", "dynamic_lookup",
+            "-o", lib_path,
+            source_path,
+            *include_paths,
+            "-ltvm"
+        ] + flags
+    else:
+        command = [
+            "clang", "-std=c++14", "-shared", "-fPIC", "-o", lib_path,
+            source_path,
+            *include_paths,
+            "-ltvm"
+        ] + flags
+
+    must_run_process(command)
+    return lib_path
+
+def load_lib(name):
+    return ctypes.CDLL(name, ctypes.RTLD_GLOBAL)
+
+def is_primitive(expr: relay.Expr):
+    return (isinstance(expr, relay.Function)
+            and expr.attrs
+            and expr.attrs.Primitive.value == 1)
+
+class AoTCompiler(ExprFunctor):
+    """
+    Takes a Relay program and converts it into a Little CPP program
+    that can in turn be converted into C++ source code.
+    """
+    def __init__(self, mod, tgt) -> None:
+        super().__init__()
+        self.mod = mod
+        self.tgt = tgt
+        self.engine = compile_engine.get()
+        self.bindings = [[]]
+        self.gv_map = {}
+
+    def add_binding(self, var, value):
+        self.bindings[-1].append((var, value))
+
+    def optimize(self, expr: Function) -> Function:
+        opts = tvm.transform.Sequential([
+            relay.transform.SimplifyInference(),
+            relay.transform.FuseOps(),
+            relay.transform.ToANormalForm()])
+        self.mod['main'] = expr
+        self.mod = opts(self.mod)
+        ret = self.mod['main']
+        return ret
+
+    def mk_primitive_op(self, func: Expr, args, output_type) -> Expr:
+        cc_key = compile_engine.CCacheKey(func, self.tgt)
+        func_hash = tvm.ir.structural_hash(func)
+        name = f"op_{func_hash}"
+        if not get_global_func(name, allow_missing=True):
+            jit_func = self.engine.jit(cc_key, self.tgt)
+            register_func(name, jit_func)
+        return PackedCall(name, args, [x.checked_type for x in args], output_type)
+
+    def visit_call(self, call: Expr) -> Expr:
+        if is_primitive(call.op):
+            return self.mk_primitive_op(call.op, call.args, call.checked_type)
+        if isinstance(call.op, Constructor):
+            return CPPConstructor(call.op.tag, [self.visit(arg) for arg in call.args])
+        assert call.attrs is None
+        args = [self.visit(arg) for arg in call.args]
+        func = self.visit(call.op)
+        return Invoke(func, args)
+
+    def visit_let(self, let: Expr) -> Expr:
+        self.bindings.append([])
+
+        while isinstance(let, Let):
+            cpp_value = self.visit(let.value)
+            self.add_binding(let.var, cpp_value)
+            let = let.body
+
+        bindings = self.bindings.pop()
+        body = self.visit(let)
+
+        return Decl(bindings, body)
+
+    def visit_var(self, var):
+        return var
+
+    def visit_global_var(self, gv):
+        if gv not in self.gv_map:
+            # Insert a placeholder first so recursive references to gv terminate.
+            self.gv_map[gv] = "to be updated"
+            self.gv_map[gv] = self.visit(self.mod[gv])
+        return gv
+
+    def visit_function(self, func):
+        if is_primitive(func):
+            body = self.mk_primitive_op(func, func.params, func.ret_type)
+            return CPPFunction(func.params, body, func.checked_type.ret_type)
+        return CPPFunction(func.params, self.visit(func.body), func.checked_type.ret_type)
+
+    def visit_constant(self, const):
+        return const
+
+    def visit_if(self, i):
+        return CPPIf(self.visit(i.cond),
+                     self.visit(i.true_branch),
+                     self.visit(i.false_branch),
+                     i.checked_type)
+
+    def visit_tuple(self, t):
+        return CPPTuple([self.visit(f) for f in t.fields], t.checked_type)
+
+    def visit_match(self, m):
+        return CPPMatch(self.visit(m.data),
+                        [(c.lhs, self.visit(c.rhs)) for c in m.clauses],
+                        m.checked_type)
+
+    def visit_op(self, op):
+        raise Exception(f'op outside of primitive: {op}')
+
+    def visit_constructor(self, ctor):
+        raise Exception('Constructors should be handled when visiting calls.')
+
+    def visit_tuple_getitem(self, t):
+        return CPPTupleGetItem(self.visit(t.tuple_value), t.index, t.checked_type)
+
+    def visit_ref_create(self, r):
+        return CPPRefCreate(self.visit(r.value), r.checked_type)
+
+    def visit_ref_read(self, r):
+        return CPPRefRead(self.visit(r.ref), r.checked_type)
+
+    def visit_ref_write(self, r):
+        return CPPRefWrite(self.visit(r.ref), self.visit(r.value))
+
+_LIB_COUNTER = 1
+_LIB = []
+
+def lib_and_func_name(name):
+    global _LIB_COUNTER
+    packed_name = f'relay.aot.{name}.{_LIB_COUNTER}'
+    lib_name = f"librelay_aot_{_LIB_COUNTER}.so"
+    _LIB_COUNTER += 1
+    return lib_name, packed_name
+
+def _mk_wrapper(func, ctx, constants, record_time):
+    def _wrapper(*args):
+        new_constants = [convert(a, ctx) for a in constants]
+        new_args = [convert(a, ctx) for a in args]
+        begin = time.perf_counter()
+        res = func(*new_constants, *new_args)
+        end = time.perf_counter()
+        return res if not record_time else (res, end - begin)
+    return _wrapper
+
+def compile_prog(func, mod, ctx, tgt, name='default', record_time=False):
+    """Compile a Relay function into a C++ file that
+    implements a program with the same semantics,
+    which calls into TVM only for operators.
+
+    Parameters
+    ----------
+    func: Expr
+        A Relay function to compile
+        (either a literal Relay function
+        or a GlobalVar that is in `mod`).
+
+    mod: IRModule
+        Module containing any functions referenced by `func`.
+
+    ctx: Context
+        The TVM context.
+
+    tgt: Target
+        The TVM target.
+
+    name: String
+        The name of the target binary library.
+
+    record_time: Bool
+        If True, the return value of the function
+        will include the program's execution time.
+
+    Returns
+    -------
+    result: Function
+        A function that, when passed some values,
+        will convert them to the right format
+        and call the compiled func (a PackedFunc).
+    """
+    global _LIB
+    if isinstance(func, GlobalVar):
+        func = mod[func]
+    assert isinstance(func, Function)
+    compiler = AoTCompiler(mod, tgt)
+    func = compiler.optimize(func)
+    func = compiler.visit(func)
+    lib_name, packed_name = lib_and_func_name(name)
+    constants, source_code = to_source.to_source(func, compiler.gv_map, ctx, packed_name)
+    library_path = compile_cpp(source_code, lib_name, flags=["-O3"])
+    _LIB.append(load_lib(library_path))
+    func = get_global_func(packed_name)
+    return _mk_wrapper(func, ctx, constants, record_time)
diff --git a/python/tvm/relay/backend/aot/convert.py b/python/tvm/relay/backend/aot/convert.py
new file mode 100644
index 000000000000..be4da397bb6b
--- /dev/null
+++ b/python/tvm/relay/backend/aot/convert.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Responsible for converting function arguments into
+a form that can be passed to a `PackedFunc`.
+"""
+import numpy as np
+import tvm
+from tvm import relay
+
+def convert(a, ctx):
+    """
+    Converts a function input `a`
+    (which may be a constant defined in Relay, a numpy array,
+    or a TVM NDArray)
+    into a form that can be passed to a TVM `PackedFunc`
+    with the given context.
+    """
+    # convert(convert(a, ctx), ctx) = convert(a, ctx)
+    while True:
+        if isinstance(a, int):
+            a = np.array(a, dtype='int32')
+        elif isinstance(a, np.ndarray):
+            a = tvm.nd.array(a, ctx)
+        elif isinstance(a, tvm.runtime.NDArray):
+            return a
+        elif isinstance(a, relay.Call):
+            assert isinstance(a.op, relay.Constructor)
+            a = (a.op, *a.args)
+        elif isinstance(a, tuple):
+            assert isinstance(a[0], relay.Constructor)
+            a = relay.backend.interpreter.ConstructorValue(
+                a[0].tag, [convert(arg, ctx) for arg in a[1:]], a[0])
+        elif isinstance(a, relay.backend.interpreter.ConstructorValue):
+            return a
+        else:
+            raise Exception(f"unsupported argument type: {a} ({type(a)})")
diff --git a/python/tvm/relay/backend/aot/little_cpp.py b/python/tvm/relay/backend/aot/little_cpp.py
new file mode 100644
index 000000000000..d72fbd8e4d3e
--- /dev/null
+++ b/python/tvm/relay/backend/aot/little_cpp.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Defines the Little CPP intermediate representation used by the AoT compiler
+(corresponds to a small subset of the C++ AST
+and a couple of TVM-specific concepts).
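+Each node is a plain attrs-generated record; the `relay_type` fields keep the
+original Relay type around so the code generator can emit correctly typed C++.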
+"""
+from typing import Any, Optional, List, Tuple
+import attr
+from tvm.relay import Var
+
+class LittleCppNode:
+    pass
+
+@attr.s(auto_attribs=True)
+class Decl(LittleCppNode):
+    bindings: List[Tuple[Var, LittleCppNode]]
+    body: LittleCppNode
+
+@attr.s(auto_attribs=True)
+class PackedCall(LittleCppNode):
+    name: str
+    args: Any
+    args_type: Any
+    ret_type: Any
+
+@attr.s(auto_attribs=True)
+class Invoke(LittleCppNode):
+    call: Any
+    args: Any
+
+@attr.s(auto_attribs=True)
+class CPPFunction(LittleCppNode):
+    params: List[Var]
+    body: Any
+    ret_type: Any
+    name: Optional[str] = None
+
+@attr.s(auto_attribs=True)
+class CPPIf(LittleCppNode):
+    cond: Any
+    true_branch: Any
+    false_branch: Any
+    relay_type: Any
+
+@attr.s(auto_attribs=True)
+class CPPTuple(LittleCppNode):
+    fields: List[Any]
+    relay_type: Any
+
+@attr.s(auto_attribs=True)
+class CPPMatch(LittleCppNode):
+    data: Any
+    clause: List[Tuple[Any, Any]]
+    relay_type: Any
+
+@attr.s(auto_attribs=True)
+class CPPConstructor(LittleCppNode):
+    tag: int
+    fields: List[Any]
+
+@attr.s(auto_attribs=True)
+class CPPTupleGetItem(LittleCppNode):
+    tuple_value: Any
+    index: int
+    relay_type: Any
+
+@attr.s(auto_attribs=True)
+class CPPRefCreate(LittleCppNode):
+    value: Any
+    relay_type: Any
+
+@attr.s(auto_attribs=True)
+class CPPRefRead(LittleCppNode):
+    ref: Any
+    relay_type: Any
+
+@attr.s(auto_attribs=True)
+class CPPRefWrite(LittleCppNode):
+    ref: Any
+    value: Any
diff --git a/python/tvm/relay/backend/aot/to_source.py b/python/tvm/relay/backend/aot/to_source.py
new file mode 100644
index 000000000000..3eb2b1e04193
--- /dev/null
+++ b/python/tvm/relay/backend/aot/to_source.py
@@ -0,0 +1,577 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Responsible for taking Little CPP ASTs and converting them
+into C++ source code.
+"""
+from tvm import relay
+from . import little_cpp
+
+class ExprWithStmt:
+    """
+    Representation of an expression that requires a
+    C++ statement to define terms used in it.
+    """
+    def __init__(self, expr, stmt=""):
+        assert isinstance(expr, str)
+        assert isinstance(stmt, str)
+        assert "ExprWithStmt" not in expr
+        assert "ExprWithStmt" not in stmt
+        self.expr = expr
+        self.stmt = stmt
+
+    def __str__(self):
+        return f"ExprWithStmt({self.expr}, {self.stmt})"
+
+    def __repr__(self):
+        return self.__str__()
+
+class ToSource:
+    """
+    Handles converting Little CPP ASTs into C++ source code.
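+
+    Visiting a node yields an ExprWithStmt: the statement part collects any
+    C++ definitions the node needs, and the expression part names its result.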
+ """ + def __init__(self, gv_map): + self.gv_map = gv_map + self.name_counter = 0 + self.source_content = "" + self.name_map = {} + self.local = True + self.declare = "" + self.declare_map = {} + self.input_const = [] + + def fresh_global_name(self): + name = f"global{self.name_counter}" + self.name_counter += 1 + return name + + def sanitize(self, name): + return name.replace("-", "_").replace("/", "_") + + def fresh_local_name(self, var=None): + if var is not None: + name = f"local_{self.sanitize(var.name_hint)}_{self.name_counter}" + else: + name = f"local_{self.name_counter}" + self.name_counter += 1 + return name + + def fresh_label_name(self): + name = f"label_{self.name_counter}" + self.name_counter += 1 + return name + + def visit(self, node, *, local=True, name=None): + """ + Visits a Little CPP node and returns C++ code as text + in the form of statements with the necessary definitions + and an expression corresponding to the result of the node. + + Parameters + ---------- + node: Little CPP node to be compiled + + local: Optional[bool] + For function definitions, specifies whether to + treat it as a local definition (if True) or global (if False) + + name: Optional[str] + Specifies a name for the node if it's a function + definition (otherwise the compiler generates a name) + + Returns + ------- + result: ExprWithStmt + Contains a C++ expression corresponding + to the result of `node` (as a string) + with one or more C++ statements (as strings) + containing needed definitions. + """ + if isinstance(node, little_cpp.PackedCall): + res = self.visit_packed_call(node) + elif isinstance(node, little_cpp.CPPFunction): + res = self.visit_cpp_function(node, local, name) + elif isinstance(node, little_cpp.Decl): + res = self.visit_decl(node) + elif isinstance(node, little_cpp.Invoke): + res = self.visit_invoke(node) + elif isinstance(node, relay.Var): + res = ExprWithStmt(self.name_map[node]) + elif isinstance(node, relay.GlobalVar): + res = self.visit_global_var(node) + elif isinstance(node, relay.Constant): + res = self.visit_constant(node) + elif isinstance(node, little_cpp.CPPIf): + res = self.visit_if(node) + elif isinstance(node, little_cpp.CPPTuple): + res = self.visit_tuple(node) + elif isinstance(node, little_cpp.CPPConstructor): + res = self.visit_constructor(node) + elif isinstance(node, little_cpp.CPPMatch): + res = self.visit_match(node) + elif isinstance(node, little_cpp.CPPTupleGetItem): + res = self.visit_tuple_getitem(node) + elif isinstance(node, little_cpp.CPPRefCreate): + res = self.visit_ref_create(node) + elif isinstance(node, little_cpp.CPPRefRead): + res = self.visit_ref_read(node) + elif isinstance(node, little_cpp.CPPRefWrite): + res = self.visit_ref_write(node) + else: + raise Exception(str(node)) + assert isinstance(res, ExprWithStmt) + return res + + def visit_ref_create(self, node): + value = self.visit(node.value) + return ExprWithStmt(f"RefValue({value.expr})", value.stmt) + + def visit_ref_read(self, node): + ref = self.visit(node.ref) + return ExprWithStmt(f"Downcast({ref.expr})->value", ref.stmt) + + def visit_ref_write(self, node): + ref = self.visit(node.ref) + value = self.visit(node.value) + stmt = ref.stmt + value.stmt + f"Downcast({ref.expr})->value={value.expr};\n" + return ExprWithStmt("runtime::ADT::Tuple()", stmt) + + def visit_tuple_getitem(self, node): + visit_tup = self.visit(node.tuple_value) + return ExprWithStmt(f"Downcast" + f"({visit_tup.expr})" + f"[{node.index}]", visit_tup.stmt) + + def visit_constructor(self, node): + args_str, _ 
+        args_str, _ = self.visit_args(node.fields)
+        return ExprWithStmt(f"TagToCV({node.tag}, {{{args_str}}})")
+
+    def pattern_var(self, pat, var_set):
+        """
+        Given a match pattern `pat` and a set of variable names `var_set`,
+        adds the variables appearing in `pat` to `var_set` and
+        raises an exception if any is already in the set
+        (the names should be distinct).
+        """
+        if isinstance(pat, relay.PatternConstructor):
+            for x in pat.patterns:
+                self.pattern_var(x, var_set)
+        elif isinstance(pat, relay.PatternVar):
+            assert pat.var not in var_set
+            var_set.add(pat.var)
+        else:
+            raise Exception(str(pat))
+
+    def visit_match(self, node):
+        """
+        Handle a match expression.
+        """
+        data = self.visit(node.data)
+        stmt_str = data.stmt
+
+        pattern_var_set = set()
+        for c in node.clause:
+            self.pattern_var(c[0], pattern_var_set)
+
+        for v in pattern_var_set:
+            bind_name = self.fresh_local_name()
+            self.name_map[v] = bind_name
+            stmt_str += f"ObjectRef {bind_name};\n"
+
+        # Match data_name against pat and fill in the pattern variables accordingly.
+        # Jump to fail_label or ok_label based on failure/success.
+        def visit_pattern(pat, data_name, fail_label, ok_label):
+            if isinstance(pat, relay.PatternConstructor):
+                data_name = f"Downcast<ConstructorValue>({data_name})"
+                ok_case = ""
+                bind_names = []
+                assert len(pat.constructor.inputs) == len(pat.patterns)
+                for i, _ in enumerate(pat.constructor.inputs):
+                    bind_name = self.fresh_local_name()
+                    bind_names.append(bind_name)
+                    ok_case += f"ObjectRef {bind_name} = {data_name}->fields[{i}];\n"
+                for bind_name, pattern in zip(bind_names, pat.patterns):
+                    next_label = self.fresh_label_name()
+                    ok_case += visit_pattern(pattern, bind_name, fail_label, next_label)
+                    ok_case += f"{next_label}:\n"
+                ok_case += f"goto {ok_label};"
+                return f"""
+                CHECK({data_name}->tag != -1);
+                if ({data_name}->tag == {pat.constructor.tag}) {{
+                  {ok_case}
+                }} else {{
+                  goto {fail_label};
+                }}
+                """
+
+            if isinstance(pat, relay.PatternVar):
+                return f"""
+                {self.name_map[pat.var]} = {data_name};
+                """
+
+            raise Exception(str(pat))
+
+        in_name = self.fresh_local_name()
+        out_name = self.fresh_local_name()
+        stmt_str += f"ObjectRef {in_name} = {data.expr};\n"
+        stmt_str += f"ObjectRef {out_name};\n"
+        match_finish_label = self.fresh_label_name()
+        for clause in node.clause:
+            clause_value = self.visit(clause[1])
+            fail_label = self.fresh_label_name()
+            ok_label = self.fresh_label_name()
+            stmt_str += f"""{{
+              {visit_pattern(clause[0], in_name, fail_label, ok_label)}
+            }}
+            """
+            stmt_str += f"""{{
+              {ok_label}:
+              {clause_value.stmt}
+              {out_name} = {clause_value.expr};
+              goto {match_finish_label};
+            }}
+            """
+            stmt_str += f"{fail_label}:\n"
+        stmt_str += """CHECK(false) << "does not match any";\n"""
+        stmt_str += f"{match_finish_label}: ;"
+        return ExprWithStmt(out_name, stmt_str)
+
+    def visit_tuple(self, node):
+        expr = []
+        stmt_str = ""
+        for field in node.fields:
+            visit_field = self.visit(field)
+            expr.append(visit_field.expr)
+            stmt_str += visit_field.stmt
+        list_name = self.fresh_local_name()
+        stmt_str += f"std::vector<ObjectRef> {list_name} = {{{inter(expr)}}};"
+        return ExprWithStmt(f"runtime::ADT::Tuple({list_name})", stmt_str)
+
+    def visit_if(self, node):
+        """
+        Handle an if-else expression.
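+
+        Both branches are lowered eagerly into statements; a fresh ObjectRef
+        holds the chosen result, and the condition tensor is read as a C++ bool.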
+ """ + cond = self.visit(node.cond) + true_branch = self.visit(node.true_branch) + false_branch = self.visit(node.false_branch) + ret_name = self.fresh_local_name() + stmt = f"ObjectRef {ret_name};" + stmt += f""" + {cond.stmt} + if (NDToBool(ObjectRefToND({cond.expr}))) {{ + {true_branch.stmt} + {ret_name} = {true_branch.expr}; + }} else {{ + {false_branch.stmt} + {ret_name} = {false_branch.expr}; + }} + """ + return ExprWithStmt(ret_name, stmt) + + def visit_constant(self, const): + if const not in self.declare_map: + name = self.fresh_global_name() + self.declare_map[const] = name + self.declare += f"ObjectRef {name};\n" + self.input_const.append((name, const.data.asnumpy())) + return ExprWithStmt(self.declare_map[const]) + + def visit_global_var(self, global_var): + if global_var not in self.declare_map: + name = self.fresh_global_name() + self.declare_map[global_var] = f"{name}" + visit_gv = self.visit(self.gv_map[global_var], + local=False, name=name) + assert visit_gv.stmt == "" + assert visit_gv.expr == f"{name}" + return ExprWithStmt(self.declare_map[global_var]) + + def visit_args(self, args): + args_str = "" + stmt_str = "" + for i, arg in enumerate(args): + visit_arg = self.visit(arg) + args_str += visit_arg.expr + stmt_str += visit_arg.stmt + if i != len(args) - 1: + args_str += ", " + return args_str, stmt_str + + def visit_invoke(self, invoke): + args_str, stmt_str = self.visit_args(invoke.args) + func = self.visit(invoke.call) + return ExprWithStmt( + f"Apply({func.expr}, std::vector({{{args_str}}}))", + stmt_str + func.stmt) + + def visit_decl(self, decl): + """ + Handles a declaration. + """ + source = "" + for var, value in decl.bindings: + local_name = self.fresh_local_name(var) + self.name_map[var] = local_name + visited_value = self.visit(value, name=local_name) + source += visited_value.stmt + source += f"""ObjectRef {local_name} = {visited_value.expr};""" + body = self.visit(decl.body) + source += body.stmt + return ExprWithStmt(body.expr, source) + + def nd_dtype(self, tensor_type): + """Given a Relay tensor type, returns the appropriate dtype name""" + assert isinstance(tensor_type, relay.ty.TensorType) + if tensor_type.dtype == 'int32': + return 'dtype_i32' + if tensor_type.dtype == 'int8': + return 'dtype_i8' + if tensor_type.dtype == 'float32': + return 'dtype_f32' + if tensor_type.dtype == 'bool': + return 'dtype_u1' + raise Exception("unknown tensor dtype: " + str(tensor_type)) + + def nd_shape(self, tensor_type): + """ + Given a Relay tensor type, returns its shape. + """ + return f"{{{inter([str(s) for s in tensor_type.shape])}}}" + + def visit_packed_call(self, call): + """ + Handle a call to a PackedFunc. 
+ """ + decl_str = "" + args = [] + for arg in call.args: + visit_arg = self.visit(arg) + decl_str += visit_arg.stmt + args.append(visit_arg.expr) + args_str = [] + + def convert_input(input_ty, arg): + if isinstance(input_ty, relay.ty.TensorType): + args_str.append(f"{arg}") + else: + assert isinstance(input_ty, relay.ty.TupleType) + tuple_name = self.fresh_local_name() + nonlocal decl_str + decl_str += (f"runtime::ADT {tuple_name} =" + f" Downcast({arg});\n") + for i, t in enumerate(input_ty.fields): + convert_input(t, f"{tuple_name}[{i}]") + assert len(call.args_type) == len(call.args) + for i in range(len(call.args_type)): + convert_input(call.args_type[i], args[i]) + + def convert_output(output_ty): + nonlocal decl_str + if isinstance(output_ty, relay.ty.TensorType): + tensor_name = self.fresh_local_name() + decl_str += (f"NDArray {tensor_name} = " + f"NDArray::Empty({self.nd_shape(output_ty)}, " + f"{self.nd_dtype(output_ty)}, context);\n") + args_str.append(f"{tensor_name}") + return tensor_name + + assert isinstance(output_ty, relay.ty.TupleType) + list_name = self.fresh_local_name() + list_members = inter([convert_output(t) for t in output_ty.fields]) + decl_str += f"std::vector {list_name} = {{{list_members}}};" + return f"runtime::ADT::Tuple({list_name})" + + out = convert_output(call.ret_type) + return ExprWithStmt(out, f""" + {decl_str} + const PackedFunc *pf = runtime::Registry::Get("{call.name}"); + CHECK(pf); + (*pf)({inter(args_str)}); + """) + + def visit_cpp_function(self, func, local, name): + """ + Handle a Little CPP function. + """ + vec = self.fresh_local_name() + body = "" + + for i, param in enumerate(func.params): + pname = self.fresh_local_name(param) + self.name_map[param] = pname + body += f"ObjectRef {pname} = {vec}.at({i});\n" + + body += f"ObjectRef {name} = self;\n" + visit_body = self.visit(func.body) + body = body + visit_body.stmt + f"""return {visit_body.expr};""" + expr = f""" + FunctionValueNode::make([=]( + const std::vector& {vec}, + const ObjectRef& self) {{ + {body} + }}); + """ + + if local: + return ExprWithStmt(expr) + + if name is None: + name = self.fresh_global_name() + + self.declare += f""" + static ObjectRef {name}_func() {{ + static ObjectRef ret = {expr}; + return ret; + }} + ObjectRef {name} = {name}_func(); + """ + return ExprWithStmt(f"{name}") + + def mk_register_api(self, name: str, func) -> str: + """ + Converts the given Little CPP function into C++ text + and registers the produced C++ function as a TVM `PackedFunc` under the given name. 
+ """ + visited_func = self.visit(func, local=False) + assert visited_func.stmt == "" + source = self.declare + + args = "" + if isinstance(func, relay.GlobalVar): + func = self.gv_map[func] + end = len(func.params) - 1 + init = "" + for i, (input_name, _) in enumerate(self.input_const): + init += f"{input_name} = args[{i}];\n" + for i in range(len(func.params)): + args += f"args[{i+len(self.input_const)}]" + if i != end: + args += ", " + + source += f""" + TVM_REGISTER_GLOBAL("{name}") + .set_body([](TVMArgs args, TVMRetValue* ret) {{ + {init} + std::initializer_list ilist = {{{args}}}; + *ret = Apply({visited_func.expr}, std::vector(ilist)); + }}); + """ + return source + +def inter(strs, sep=", "): + ret = "" + for i, string in enumerate(strs): + ret += string + if i != len(strs) - 1: + ret += sep + return ret + +def mk_file(body, ctx): + return f""" + #include + #include + #include + #include + #include + #include + + using namespace tvm; + using namespace runtime; + using namespace relay; + + static DLDataType dtype_f32 = DLDataType {{ .code = DLDataTypeCode::kDLFloat, .bits = 32, .lanes = 1 }}; + static DLDataType dtype_u32 = DLDataType {{ .code = DLDataTypeCode::kDLUInt, .bits = 32, .lanes = 1 }}; + static DLDataType dtype_u1 = DLDataType {{ .code = DLDataTypeCode::kDLUInt, .bits = 1, .lanes = 1 }}; + static DLDataType dtype_i32 = DLDataType {{ .code = DLDataTypeCode::kDLInt, .bits = 32, .lanes = 1 }}; + static DLDataType dtype_i8 = DLDataType {{ .code = DLDataTypeCode::kDLInt, .bits = 8, .lanes = 1 }}; + static DLContext context = DLContext {{ .device_type = DLDeviceType({ctx.device_type}), .device_id = {ctx.device_id} }}; + + static bool NDToBool(const NDArray& nd) {{ + DLContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + NDArray cpu_array = nd.CopyTo(cpu_ctx); + CHECK_EQ(DataType(cpu_array->dtype), DataType::Bool()); + return reinterpret_cast(cpu_array->data)[0]; + }} + + static NDArray ObjectRefToND(const ObjectRef& v) {{ + return Downcast(v); + }} + + static ConstructorValue TagToCV(size_t tag, const tvm::Array& fields) {{ + ObjectPtr n = make_object(); + ObjectPtr con = make_object(); + con->tag = tag; + n->tag = tag; + n->constructor = Constructor(con); + n->fields = fields; + return ConstructorValue(n); + }} + + /*! \\brief A Function value. 
+     */
+    class FunctionValue;
+
+    using function_value_t = std::function<ObjectRef(const std::vector<ObjectRef>&, const ObjectRef&)>;
+    struct FunctionValueNode : Object {{
+      function_value_t f;
+
+      FunctionValueNode() {{ }}
+
+      void VisitAttrs(tvm::AttrVisitor* v) {{ }}
+
+      TVM_DLL static FunctionValue make(const function_value_t& f);
+
+      static constexpr const char* _type_key = "relay.FunctionValue";
+      TVM_DECLARE_FINAL_OBJECT_INFO(FunctionValueNode, Object);
+    }};
+
+    class FunctionValue : public ObjectRef {{
+     public:
+      TVM_DEFINE_OBJECT_REF_METHODS(FunctionValue, ObjectRef, FunctionValueNode);
+    }};
+
+    FunctionValue FunctionValueNode::make(const function_value_t& f) {{
+      ObjectPtr<FunctionValueNode> n = make_object<FunctionValueNode>();
+      n->f = f;
+      return FunctionValue(n);
+    }}
+
+    ObjectRef Apply(const ObjectRef& op, const std::vector<ObjectRef>& args) {{
+      return Downcast<FunctionValue>(op)->f(args, op);
+    }}
+
+    {body}
+    """
+
+def to_source(program, gv_map, ctx, name):
+    convert = ToSource(gv_map)
+    ret = mk_file(convert.mk_register_api(name, program), ctx)
+    return [value for name, value in convert.input_const], ret
diff --git a/tests/python/relay/test_aot.py b/tests/python/relay/test_aot.py
new file mode 100644
index 000000000000..e8021a52f0ce
--- /dev/null
+++ b/tests/python/relay/test_aot.py
@@ -0,0 +1,306 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
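+"""End-to-end tests for the Relay ahead-of-time (AoT) compiler."""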
+
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm import IRModule as Module
+from tvm.relay import var, Function, op, GlobalVar, TypeVar, FuncType
+from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions
+from tvm.relay.backend import aot
+
+
+def compile(f, mod):
+    tgt = tvm.target.create('llvm')
+    ctx = tvm.context('llvm', 0)
+    return aot.compile_prog(f, mod, ctx=ctx, tgt=tgt)
+
+
+def test_identity():
+    mod = Module()
+    x = var('x', shape=())
+    func = Function([x], x)
+    cfunc = compile(func, mod)
+    a = tvm.nd.array(np.array(1.0, dtype='float32'))
+    output = cfunc(a)
+    np.testing.assert_allclose(output.asnumpy(), a.asnumpy())
+
+
+def test_add():
+    mod = Module()
+    x = var('x', shape=())
+    y = var('y', shape=())
+    z = x + y
+    func = Function([x, y], z)
+    cfunc = compile(func, mod)
+    a = tvm.nd.array(np.array(1.0, dtype='float32'))
+    b = tvm.nd.array(np.array(1.0, dtype='float32'))
+    c = tvm.nd.array(np.array(2.0, dtype='float32'))
+    output = cfunc(a, b)
+    np.testing.assert_allclose(output.asnumpy(), c.asnumpy())
+
+
+def test_mult_op():
+    mod = Module()
+    x = var('x', shape=())
+    y = var('y', shape=())
+    z = x + y
+    zz = op.exp(z)
+    func = Function([x, y], zz)
+    cfunc = compile(func, mod)
+    a = tvm.nd.array(np.array(1.0, dtype='float32'))
+    b = tvm.nd.array(np.array(1.0, dtype='float32'))
+    output = cfunc(a, b)
+    np.testing.assert_allclose(output.asnumpy(), np.exp(a.asnumpy() + b.asnumpy()))
+
+
+def test_double():
+    mod = Module()
+    x = var('x', shape=())
+    double = GlobalVar('double')
+    mod[double] = Function([x], x + x)
+    x = var('x', shape=())
+    cfunc = compile(Function([x], double(double(x))), mod)
+    a = tvm.nd.array(np.array(1.5, dtype='float32'))
+    output = cfunc(a)
+    np.testing.assert_allclose(output.asnumpy(), np.array(6.0, dtype='float32'))
+
+
+def test_42():
+    mod = Module()
+    func = Function([], relay.const(42))
+    cfunc = compile(func, mod)
+    output = cfunc()
+    np.testing.assert_allclose(output.asnumpy(), np.array(42.0, dtype='float32'))
+
+
+def test_add_42():
+    mod = Module()
+    x = var('x', shape=())
+    func = Function([x], x + relay.const(42.0))
+    cfunc = compile(func, mod)
+    a = tvm.nd.array(np.array(42.0, dtype='float32'))
+    output = cfunc(a)
+    np.testing.assert_allclose(output.asnumpy(), np.array(84.0, dtype='float32'))
+
+
+def test_int_mult_3():
+    mod = Module()
+    x = var('x', dtype='int32', shape=())
+    func = Function([x], x * relay.const(3))
+    cfunc = compile(func, mod)
+    a = tvm.nd.array(np.array(4, dtype='int32'))
+    output = cfunc(a)
+    np.testing.assert_allclose(output.asnumpy(), np.array(12, dtype='int32'))
+
+
+def test_abs():
+    mod = Module()
+    x = var('x', shape=())
+    func = Function([x], relay.If(op.less(x, relay.const(0.0)), relay.const(-1.0) * x, x))
+    cfunc = compile(func, mod)
+    a = tvm.nd.array(np.array(12.0, dtype='float32'))
+    output = cfunc(a)
+    np.testing.assert_allclose(output.asnumpy(), np.array(12.0, dtype='float32'))
+    a = tvm.nd.array(np.array(-34.0, dtype='float32'))
+    output = cfunc(a)
+    np.testing.assert_allclose(output.asnumpy(), np.array(34.0, dtype='float32'))
+
+
+def test_recur_sum_global():
+    mod = Module()
+    x = var('x', dtype='int32', shape=())
+    sum = GlobalVar('sum')
+    c = relay.const(0)
+    mod[sum] = Function([x],
+                        relay.If(op.less(x, c), c, x + sum(x - relay.const(1))),
+                        relay.TensorType(dtype='int32', shape=()))
+    cfunc = compile(Function([], sum(relay.const(10))), mod)
+    output = cfunc()
+    np.testing.assert_allclose(output.asnumpy(), np.array(55, dtype='int32'))
+
+
+def nat_to_int(n):
+    if n.constructor.tag & 0xff == 1:
+        return 1 + nat_to_int(n.fields[0])
+    else:
+        assert n.constructor.tag & 0xff == 0
+        return 0
+
+
+def int_to_nat(p, i):
+    if i > 0:
+        return p.s(int_to_nat(p, i - 1))
+    else:
+        assert i == 0
+        return p.z()
+
+
+def test_nat_3():
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    cfunc = compile(Function([], p.s(p.s(p.s(p.z())))), mod)
+    output = cfunc()
+    assert nat_to_int(output) == 3
+
+
+def test_nat_add():
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    cfunc = compile(Function([], p.add(p.s(p.s(p.s(p.z()))), p.s(p.s(p.s(p.s(p.z())))))), mod)
+    output = cfunc()
+    assert nat_to_int(output) == 7
+
+
+def test_add_convert():
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    cfunc = compile(p.add, mod)
+    output = cfunc(int_to_nat(p, 12), int_to_nat(p, 34))
+    assert nat_to_int(output) == 46
+
+
+def test_ref():
+    mod = Module()
+    three_with_ref = relay.GlobalVar('three_with_ref')
+    i = relay.Var('i')
+    iv = relay.Var('iv')
+    u = relay.Var('u')
+    uv = relay.Var('uv')
+    body = relay.add(iv, uv)
+    body = relay.Let(uv, relay.RefRead(i), body)
+    body = relay.Let(u, relay.RefWrite(i, relay.const(2, dtype='int32')), body)
+    body = relay.Let(iv, relay.RefRead(i), body)
+    body = relay.Let(i, relay.RefCreate(relay.const(1, dtype='int32')), body)
+    mod[three_with_ref] = relay.Function([], body)
+    cfunc = compile(three_with_ref, mod)
+    output = cfunc()
+    np.testing.assert_allclose(output.asnumpy(), np.array(3, dtype='int32'))
+
+
+def test_tuple():
+    mod = Module()
+    cfunc = compile(Function([],
+                             relay.TupleGetItem(relay.Tuple([relay.const(3, dtype='int32'),
+                                                             relay.const(4.0, dtype='float32')]),
+                                                1)),
+                    mod)
+    np.testing.assert_allclose(cfunc().asnumpy(), np.array(4.0, dtype='float32'))
+
+
+def test_get_valid_counts():
+    # Based on test_get_valid_counts in test_op_level5.
+    # Tests the case of a packed func returning a Relay tuple.
+    # Only checks the shapes of the output because the reference implementation
+    # is long and inconvenient.
+    shape = (1, 2500, 6)
+    score_threshold = 0
+    id_index = 0
+    score_index = 1
+    np_data = np.random.uniform(low=-2, high=2, size=shape).astype("float32")
+    mod = Module()
+    cfunc = compile(
+        relay.Function(
+            [],
+            relay.vision.get_valid_counts(
+                relay.const(np_data), score_threshold, id_index, score_index
+            ).astuple()),
+        mod)
+
+    relay_out = cfunc()
+    out1 = relay_out[0].asnumpy()
+    out2 = relay_out[1].asnumpy()
+    assert out1.shape == (shape[0],)
+    assert out2.shape == shape
+
+
+def test_compose():
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+    x = relay.Var('x')
+    inc = GlobalVar('inc')
+    mod[inc] = Function([x], p.s(x))
+    x = relay.Var('x')
+    func = GlobalVar('func')
+    f = Function([x], relay.Call(p.compose(inc, p.double), [x]))
+    mod[func] = f
+    cfunc = compile(func, mod)
+    assert nat_to_int(cfunc(p.s(p.s(p.z())))) == 5
+
+
+def test_recur_sum_local():
+    mod = Module()
+    x = var('x', dtype='int32', shape=())
+    t = relay.TensorType(dtype='int32', shape=())
+    sum = relay.Var('sum', type_annotation=relay.FuncType([t], t))
+    c = relay.const(0)
+    func = Function([x],
+                    relay.If(op.less(x, c), c, x + sum(x - relay.const(1))),
+                    t)
+    body = relay.Let(sum, func, sum(relay.const(10)))
+    cfunc = compile(Function([], body), mod)
+    output = cfunc()
+    np.testing.assert_allclose(output.asnumpy(), np.array(55, dtype='int32'))
+
+
+def test_local_local_rec_outer_scope():
+    mod = Module()
+    x = var('x', dtype='int32', shape=())
+    t = relay.TensorType(dtype='int32', shape=())
+    sum = relay.Var('sum', type_annotation=relay.FuncType([t], t))
+    c = relay.const(0)
+
+    # We define a locally recursive function inside another function's scope
+    # and have that function return the closure of the locally recursive function.
+    inner_func = Function([x],
+                          relay.If(op.less(x, c), c, x + sum(x - relay.const(1))),
+                          t)
+    outer_func_body = relay.Let(sum, inner_func, sum)
+    outer_func = Function([], outer_func_body)
+    f = relay.Var('f')
+    body = relay.Let(f, outer_func(), f(relay.const(10)))
+    cfunc = compile(Function([], body), mod)
+    output = cfunc()
+    np.testing.assert_allclose(output.asnumpy(), np.array(55, dtype='int32'))
+
+
+if __name__ == "__main__":
+    test_identity()
+    test_add()
+    test_mult_op()
+    test_double()
+    test_42()
+    test_add_42()
+    test_int_mult_3()
+    test_abs()
+    test_recur_sum_global()
+    test_nat_3()
+    test_nat_add()
+    test_add_convert()
+    test_ref()
+    test_tuple()
+    test_get_valid_counts()
+    test_compose()
+    test_recur_sum_local()
+    test_local_local_rec_outer_scope()