# From a0a0e69904b55c8a12d93357d221e3106664db2b Mon Sep 17 00:00:00 2001
# From: Marisa Kirisame
# Date: Wed, 5 Aug 2020 19:09:56 -0700
# Subject: [PATCH 1/4] Add prototype Relay AoT compiler directly into TVM
#
# ---
#  python/tvm/runtime/aot/__init__.py   |   1 +
#  python/tvm/runtime/aot/aot.py        | 261 +++++++++++++++
#  python/tvm/runtime/aot/convert.py    |  23 ++
#  python/tvm/runtime/aot/little_cpp.py |  74 +++++
#  python/tvm/runtime/aot/to_source.py  | 471 +++++++++++++++++++++++++++
#  tests/python/relay/test_aot.py       | 287 ++++++++++++++++
#  6 files changed, 1117 insertions(+)
#  create mode 100644 python/tvm/runtime/aot/__init__.py
#  create mode 100644 python/tvm/runtime/aot/aot.py
#  create mode 100644 python/tvm/runtime/aot/convert.py
#  create mode 100644 python/tvm/runtime/aot/little_cpp.py
#  create mode 100644 python/tvm/runtime/aot/to_source.py
#  create mode 100644 tests/python/relay/test_aot.py

# ---------------------------------------------------------------------------
# diff --git a/python/tvm/runtime/aot/__init__.py b/python/tvm/runtime/aot/__init__.py
# new file mode 100644
# index 000000000000..b02a1c672b81
# ---------------------------------------------------------------------------
from .aot import compile

# ---------------------------------------------------------------------------
# diff --git a/python/tvm/runtime/aot/aot.py b/python/tvm/runtime/aot/aot.py
# new file mode 100644
# index 000000000000..d1a9ff2cc963
# ---------------------------------------------------------------------------
import ctypes
import os
import subprocess
import tempfile
import time
import tvm
from tvm import relay, get_global_func, target, register_func
from tvm.relay.function import Function
from tvm.relay.expr import Expr, Let, GlobalVar
from tvm.relay.adt import Constructor
from tvm.relay.expr_functor import ExprFunctor, ExprVisitor
from tvm.relay.backend import compile_engine
from .little_cpp import PackedCall, CPPFunction, Invoke, Decl, CPPIf, CPPTuple, CPPMatch, CPPConstructor, CPPTupleGetItem
from .little_cpp import CPPRefCreate, CPPRefRead, CPPRefWrite
from . import to_source
from .convert import convert

# Root of the TVM checkout; used for include/library paths when invoking clang.
# NOTE: raises KeyError at import time if TVM_HOME is not set.
TVM_PATH = os.environ['TVM_HOME']


def must_run_process(args):
    """Run the command `args` and fail loudly if it exits non-zero.

    Raises
    ------
    subprocess.CalledProcessError
        If the process returns a non-zero exit status. (An `assert` was used
        before, which is stripped under ``python -O`` and would let a failed
        compile go unnoticed.)
    """
    subprocess.run(args, check=True)


def compile_cpp(source, lib_name, flags=None, lib_path=None):
    """Compile C++ `source` into a shared library with clang.

    Parameters
    ----------
    source: str
        The C++ source text.

    lib_name: str
        File name of the shared library to produce.

    flags: Optional[List[str]]
        Extra flags appended to the clang command line.

    lib_path: Optional[str]
        Directory that receives a debug copy of the source (`source.cc`).
        Defaults to the current directory. The library itself is always
        written to a fresh temporary directory.

    Returns
    -------
    lib_path: str
        Path of the compiled shared library.
    """
    flags = [] if flags is None else list(flags)

    if lib_path is None:
        lib_path = os.curdir

    # Write out the file for debugging.
    debug_source_path = os.path.join(lib_path, 'source.cc')
    with open(debug_source_path, 'w') as source_file:
        source_file.write(source)

    # mkdtemp rather than a TemporaryDirectory context manager: the returned
    # library path lives inside this directory and is loaded by the caller
    # after this function returns, so the directory must outlive the call.
    tmpdir = tempfile.mkdtemp(prefix="relay_aot_compiler")
    lib_path = os.path.join(tmpdir, lib_name)
    source_path = os.path.join(tmpdir, 'source.cc')
    with open(source_path, 'w') as source_file:
        source_file.write(source)

    must_run_process(["clang-format", "-i", debug_source_path])

    # Only the shared-library linking flags differ between platforms.
    system = os.uname()[0]
    if system == 'Darwin':
        platform_flags = ["-undefined", "dynamic_lookup"]
    else:
        platform_flags = ["-fPIC"]

    command = [
        "clang",
        "-std=c++14",
        "-shared",
        *platform_flags,
        "-o",
        lib_path,
        source_path,
        f"-I{TVM_PATH}/3rdparty/dmlc-core/include",
        f"-I{TVM_PATH}/3rdparty/dlpack/include",
        f"-I{TVM_PATH}/3rdparty/HalideIR/src",
        f"-I{TVM_PATH}/include",
        f"-L{TVM_PATH}/build",
        "-ltvm",
    ] + flags

    must_run_process(command)
    return lib_path


def load_lib(name):
    """Load the shared library `name` with RTLD_GLOBAL and return the handle."""
    return ctypes.CDLL(name, ctypes.RTLD_GLOBAL)


def is_primitive(e: relay.Expr):
    """Return True if `e` is a Relay function whose `Primitive` attribute is 1."""
    return bool(isinstance(e, relay.Function) and e.attrs and e.attrs.Primitive.value == 1)


class AoTCompiler(ExprFunctor):
    """Lower a Relay expression into the `little_cpp` intermediate form.

    Primitive (fused) functions become `PackedCall`s to JIT-compiled operators
    registered as global functions; everything else is mapped onto the
    corresponding `little_cpp` node.
    """

    def __init__(self, mod, tgt) -> None:
        super().__init__()
        self.mod = mod
        self.tgt = tgt
        self.engine = compile_engine.get()
        # Stack of binding lists; visit_let pushes/pops one list per let-chain.
        self.bindings = [[]]
        # GlobalVar -> lowered definition, filled lazily by visit_global_var.
        self.gv_map = {}

    def add_binding(self, var, value):
        """Record `(var, value)` in the innermost binding list."""
        self.bindings[-1].append((var, value))

    def optimize(self, expr: Function) -> Function:
        """Install `expr` as `main` in the module, run the passes, return the new `main`.

        Note that this assigns to `self.mod['main']` before the passes run.
        """
        opts = tvm.transform.Sequential([
            relay.transform.SimplifyInference(),
            relay.transform.FuseOps(),
            relay.transform.ToANormalForm()])
        self.mod['main'] = expr
        self.mod = opts(self.mod)
        ret = self.mod['main']
        return ret

    def mk_primitive_op(self, func: Expr, args, output_type) -> Expr:
        """JIT `func` (once per structural hash) and return a `PackedCall` to it."""
        cc_key = compile_engine.CCacheKey(func, self.tgt)
        func_hash = tvm.ir.structural_hash(func)
        name = f"op_{func_hash}"
        # Register the operator only the first time this hash is seen.
        if not get_global_func(name, allow_missing=True):
            jit_func = self.engine.jit(cc_key, self.tgt)
            register_func(name, jit_func)
        return PackedCall(name, args, [x.checked_type for x in args], output_type)

    def visit_call(self, call: Expr) -> Expr:
        if is_primitive(call.op):
            return self.mk_primitive_op(call.op, call.args, call.checked_type)
        elif isinstance(call.op, Constructor):
            return CPPConstructor(call.op.tag, [self.visit(arg) for arg in call.args])
        else:
            assert call.attrs is None
            args = [self.visit(arg) for arg in call.args]
            fn = self.visit(call.op)
            return Invoke(fn, args)

    def visit_let(self, let: Expr) -> Expr:
        """Flatten a chain of nested lets into a single `Decl`."""
        self.bindings.append([])

        while isinstance(let, Let):
            cpp_value = self.visit(let.value)
            self.add_binding(let.var, cpp_value)
            let = let.body

        bindings = self.bindings.pop()
        body = self.visit(let)

        return Decl(bindings, body)

    def visit_var(self, var):
        return var

    def visit_global_var(self, gv):
        if gv not in self.gv_map:
            # Placeholder inserted first so a recursive reference to `gv`
            # while lowering its own body does not recurse forever.
            self.gv_map[gv] = "to be updated"
            self.gv_map[gv] = self.visit(self.mod[gv])
        return gv

    def visit_function(self, func):
        if is_primitive(func):
            body = self.mk_primitive_op(func, func.params, func.ret_type)
            return CPPFunction(func.params, body, func.checked_type.ret_type)
        else:
            return CPPFunction(func.params, self.visit(func.body), func.checked_type.ret_type)

    def visit_constant(self, const):
        return const

    def visit_if(self, i):
        return CPPIf(self.visit(i.cond),
                     self.visit(i.true_branch),
                     self.visit(i.false_branch),
                     i.checked_type)

    def visit_tuple(self, t):
        return CPPTuple([self.visit(f) for f in t.fields], t.checked_type)

    def visit_match(self, m):
        return CPPMatch(self.visit(m.data),
                        [(c.lhs, self.visit(c.rhs)) for c in m.clauses],
                        m.checked_type)

    def visit_op(self, op):
        raise Exception(f'op outside of primitive: {op}')

    def visit_tuple_getitem(self, t):
        return CPPTupleGetItem(self.visit(t.tuple_value), t.index, t.checked_type)

    def visit_ref_create(self, r):
        return CPPRefCreate(self.visit(r.value), r.checked_type)

    def visit_ref_read(self, r):
        return CPPRefRead(self.visit(r.ref), r.checked_type)

    def visit_ref_write(self, r):
        return CPPRefWrite(self.visit(r.ref), self.visit(r.value))


# Monotonic counter used to give every compiled library / packed function a
# unique name within this process.
_LIB_COUNTER = 1
# Handles of loaded libraries, kept so they are not garbage collected.
_LIB = []


def lib_and_func_name(name):
    """Return a fresh `(library file name, packed function name)` pair."""
    global _LIB_COUNTER
    packed_name = f'relay.aot.{name}.{_LIB_COUNTER}'
    lib_name = f"librelay_aot_{_LIB_COUNTER}.so"
    _LIB_COUNTER += 1
    return lib_name, packed_name


def _mk_wrapper(fn, ctx, constants, record_time):
    """Wrap packed function `fn` so callers can pass plain Python/NumPy values.

    The wrapper converts `constants` and the call arguments with `convert`,
    passes the constants first, and optionally returns `(result, seconds)`
    where only the call to `fn` itself is timed.
    """
    def _wrapper(*args):
        new_constants = [convert(a, ctx) for a in constants]
        new_args = [convert(a, ctx) for a in args]
        begin = time.perf_counter()
        res = fn(*new_constants, *new_args)
        end = time.perf_counter()
        return res if not record_time else (res, end - begin)
    return _wrapper


def compile(func, mod, ctx, tgt, name='default', record_time=False):
    """Compile a Relay function into a C++ file that
    implements a program with the same semantics,
    which calls into TVM only for operators.

    Parameters
    ----------
    func: Expr
        A Relay function to compile
        (either a literal Relay function
        or a GlobalVar that is in `mod`).

    mod: IRModule
        Module containing any functions referenced by `func`.

    ctx: Context
        The TVM context.

    tgt: Target
        The TVM target.

    name: String
        The name of the target binary library.

    record_time: Bool
        If True, the return value of the function
        will include the program's execution time.

    Returns
    -------
    result: Function
        A function that, when passed some values,
        will convert them to the right format
        and call the compiled func (a PackedFunc).
    """
    if isinstance(func, GlobalVar):
        func = mod[func]
    assert isinstance(func, Function)
    compiler = AoTCompiler(mod, tgt)
    func = compiler.optimize(func)
    func = compiler.visit(func)
    # lib_and_func_name already yields a library name consistent with
    # packed_name; recomputing it from _LIB_COUNTER afterwards (as was done
    # before) used the already-incremented counter and so mismatched the index.
    lib_name, packed_name = lib_and_func_name(name)
    constants, source_code = to_source.to_source(mod, func, compiler.gv_map, ctx, packed_name)
    library_path = compile_cpp(source_code, lib_name, flags=["-O3"])
    _LIB.append(load_lib(library_path))
    fn = get_global_func(packed_name)
    return _mk_wrapper(fn, ctx, constants, record_time)

# ---------------------------------------------------------------------------
# diff --git a/python/tvm/runtime/aot/convert.py b/python/tvm/runtime/aot/convert.py
# new file mode 100644
# index 000000000000..5daaa83b66a8
# ---------------------------------------------------------------------------
import numpy as np
import tvm
from tvm import relay


# convert(convert(a, tg), tg) = convert(a, tg)
def convert(a, ctx):
    """Convert a Python-side value into a runtime value on `ctx`.

    Each loop iteration rewrites `a` one step closer to a runtime value:
    int -> 0-d int32 ndarray -> tvm NDArray; a constructor `relay.Call`
    -> `(constructor, *args)` tuple -> `ConstructorValue` with recursively
    converted fields. NDArrays and ConstructorValues are returned unchanged.

    Raises
    ------
    Exception
        If `a` is of an unsupported type.
    """
    while True:
        if isinstance(a, int):
            a = np.array(a, dtype='int32')
        elif isinstance(a, np.ndarray):
            a = tvm.nd.array(a, ctx)
        elif isinstance(a, tvm.runtime.NDArray):
            return a
        elif isinstance(a, relay.Call):
            assert isinstance(a.op, relay.Constructor)
            a = (a.op, *a.args)
        elif isinstance(a, tuple):
            assert isinstance(a[0], relay.Constructor)
            a = relay.backend.interpreter.ConstructorValue(a[0].tag, [convert(arg, ctx) for arg in a[1:]], a[0])
        elif isinstance(a, relay.backend.interpreter.ConstructorValue):
            return a
        else:
            raise Exception(a, type(a))

# ---------------------------------------------------------------------------
# diff --git a/python/tvm/runtime/aot/little_cpp.py b/python/tvm/runtime/aot/little_cpp.py
# new file mode 100644
# index 000000000000..f3b819d95672
# ---------------------------------------------------------------------------
# python/tvm/runtime/aot/little_cpp.py (new file, 74 lines in the patch)
#
# Node classes of the small C++-like intermediate form used by the AoT
# compiler. Each class is a plain attrs record; field order is the positional
# constructor order.
from tvm.relay import Var, TypeVar
from typing import Any, Optional, List, Tuple
import attr


class LittleCppNode:
    """Common base class of all little_cpp nodes."""
    pass


@attr.s(auto_attribs=True)
class Decl(LittleCppNode):
    """A sequence of `(var, value)` bindings followed by a body."""
    bindings: List[Tuple[Var, LittleCppNode]]
    body: LittleCppNode


@attr.s(auto_attribs=True)
class PackedCall(LittleCppNode):
    """A call to a packed function looked up by `name`."""
    name: str
    args: Any
    args_type: Any
    ret_type: Any


@attr.s(auto_attribs=True)
class Invoke(LittleCppNode):
    """Application of the function value `call` to `args`."""
    call: Any
    args: Any


@attr.s(auto_attribs=True)
class CPPFunction(LittleCppNode):
    """A function with parameters, a body and a return type; `name` is optional."""
    params: List[Var]
    body: Any
    ret_type: Any
    name: Optional[str] = None


@attr.s(auto_attribs=True)
class CPPIf(LittleCppNode):
    """A conditional with both branches and the Relay type of the result."""
    cond: Any
    true_branch: Any
    false_branch: Any
    relay_type: Any


@attr.s(auto_attribs=True)
class CPPTuple(LittleCppNode):
    """A tuple of fields together with its Relay type."""
    fields: List[Any]
    relay_type: Any


@attr.s(auto_attribs=True)
class CPPMatch(LittleCppNode):
    """A pattern match on `data`; `clause` is a list of `(pattern, body)` pairs."""
    data: Any
    clause: List[Tuple[Any, Any]]
    relay_type: Any


@attr.s(auto_attribs=True)
class CPPConstructor(LittleCppNode):
    """Construction of an ADT value identified by constructor `tag`."""
    tag: int
    fields: List[Any]


@attr.s(auto_attribs=True)
class CPPTupleGetItem(LittleCppNode):
    """Projection of element `index` out of `tuple_value`."""
    tuple_value: Any
    index: int
    relay_type: Any


@attr.s(auto_attribs=True)
class CPPRefCreate(LittleCppNode):
    """Creation of a reference holding `value`."""
    value: Any
    relay_type: Any


@attr.s(auto_attribs=True)
class CPPRefRead(LittleCppNode):
    """Read of the value held by `ref`."""
    ref: Any
    relay_type: Any


@attr.s(auto_attribs=True)
class CPPRefWrite(LittleCppNode):
    """Write of `value` into `ref`."""
    ref: Any
    value: Any

# ---------------------------------------------------------------------------
# diff --git a/python/tvm/runtime/aot/to_source.py b/python/tvm/runtime/aot/to_source.py
# new file mode 100644
# index 000000000000..a5aa373d9b4a
# ---------------------------------------------------------------------------
from . import little_cpp
from tvm import relay
from tvm.relay.prelude import Prelude

# NOTE(review): in this copy of the patch the C++ templates below have lost
# everything that was written between angle brackets (header names after
# `#include`, template arguments of `Downcast`, `std::vector`, `ObjectPtr`,
# `make_object`, `reinterpret_cast`, `std::function`, ...). The strings are
# reproduced as found; restore them from the upstream commit before use.


class ExprWithStmt:
    """A C++ expression plus the statements that must run before it.

    Both parts are plain strings. The assertions guard against accidentally
    formatting an `ExprWithStmt` object into generated source.
    """

    def __init__(self, expr, stmt=""):
        assert isinstance(expr, str)
        assert isinstance(stmt, str)
        assert "ExprWithStmt" not in expr
        assert "ExprWithStmt" not in stmt
        self.expr = expr
        self.stmt = stmt

    def __str__(self):
        return f"ExprWithStmt({self.expr}, {self.stmt})"

    def __repr__(self):
        return self.__str__()


class ToSource:
    """Translate little_cpp nodes into C++ source text."""

    def __init__(self, gv_map):
        # GlobalVar -> lowered little_cpp definition.
        self.gv_map = gv_map
        # Shared counter behind every fresh global/local/label name.
        self.name_counter = 0
        self.source_content = ""
        # relay.Var -> C++ identifier currently bound to it.
        self.name_map = {}
        self.local = True
        # File-scope C++ declarations accumulated during translation.
        self.declare = ""
        # Constant / GlobalVar -> name of its file-scope C++ variable.
        self.declare_map = {}
        # (C++ name, numpy value) for every constant, in first-use order.
        self.input_const = []

    def fresh_global_name(self):
        name = f"global{self.name_counter}"
        self.name_counter += 1
        return name

    def sanitize(self, str):
        """Replace `-` and `/` in a name hint with `_`."""
        return str.replace("-", "_").replace("/", "_")

    def fresh_local_name(self, var=None):
        if var is not None:
            name = f"local_{self.sanitize(var.name_hint)}_{self.name_counter}"
        else:
            name = f"local_{self.name_counter}"
        self.name_counter += 1
        return name

    def fresh_label_name(self):
        name = f"label_{self.name_counter}"
        self.name_counter += 1
        return name

    def visit(self, node, *, local=True, name=None):
        """Translate `node` and return an `ExprWithStmt`.

        `local` and `name` are only consulted for `CPPFunction` nodes.
        """
        if isinstance(node, little_cpp.PackedCall):
            res = self.visit_packed_call(node)
        elif isinstance(node, little_cpp.CPPFunction):
            res = self.visit_cpp_function(node, local, name)
        elif isinstance(node, little_cpp.Decl):
            res = self.visit_decl(node)
        elif isinstance(node, little_cpp.Invoke):
            res = self.visit_invoke(node)
        elif isinstance(node, relay.Var):
            res = ExprWithStmt(self.name_map[node])
        elif isinstance(node, relay.GlobalVar):
            res = self.visit_global_var(node)
        elif isinstance(node, relay.Constant):
            res = self.visit_constant(node)
        elif isinstance(node, little_cpp.CPPIf):
            res = self.visit_if(node)
        elif isinstance(node, little_cpp.CPPTuple):
            res = self.visit_tuple(node)
        elif isinstance(node, little_cpp.CPPConstructor):
            res = self.visit_constructor(node)
        elif isinstance(node, little_cpp.CPPMatch):
            res = self.visit_match(node)
        elif isinstance(node, little_cpp.CPPTupleGetItem):
            res = self.visit_tuple_getitem(node)
        elif isinstance(node, little_cpp.CPPRefCreate):
            res = self.visit_ref_create(node)
        elif isinstance(node, little_cpp.CPPRefRead):
            res = self.visit_ref_read(node)
        elif isinstance(node, little_cpp.CPPRefWrite):
            res = self.visit_ref_write(node)
        else:
            raise Exception(str(node))
        assert isinstance(res, ExprWithStmt)
        return res

    def visit_ref_create(self, node):
        vv = self.visit(node.value)
        return ExprWithStmt(f"RefValue({vv.expr})", vv.stmt)

    def visit_ref_read(self, node):
        vr = self.visit(node.ref)
        return ExprWithStmt(f"Downcast({vr.expr})->value", vr.stmt)

    def visit_ref_write(self, node):
        vr = self.visit(node.ref)
        vv = self.visit(node.value)
        stmt = vr.stmt + vv.stmt + f"Downcast({vr.expr})->value={vv.expr};\n"
        # A write evaluates to the empty tuple.
        return ExprWithStmt("runtime::ADT::Tuple()", stmt)

    def visit_tuple_getitem(self, node):
        vt = self.visit(node.tuple_value)
        return ExprWithStmt(f"Downcast({vt.expr})[{node.index}]", vt.stmt)

    def visit_constructor(self, node):
        args_str, stmt_str = self.visit_args(node.fields)
        # The statements needed to evaluate the fields must be carried along;
        # they were previously dropped, leaving the field expressions
        # referring to locals that were never declared.
        return ExprWithStmt(f"TagToCV({node.tag}, {{{args_str}}})", stmt_str)

    def pattern_var(self, pat, var_set):
        """Collect every variable bound by `pat` into `var_set`."""
        if isinstance(pat, relay.PatternConstructor):
            for x in pat.patterns:
                self.pattern_var(x, var_set)
        elif isinstance(pat, relay.PatternVar):
            assert pat.var not in var_set
            var_set.add(pat.var)
        else:
            raise Exception(str(pat))

    def visit_match(self, node):
        vd = self.visit(node.data)
        stmt_str = vd.stmt

        pattern_var_set = set()
        for c in node.clause:
            self.pattern_var(c[0], pattern_var_set)

        # Declare all pattern variables up front so the gotos emitted below
        # never jump over an initialised declaration.
        for v in pattern_var_set:
            bind_name = self.fresh_local_name()
            self.name_map[v] = bind_name
            stmt_str += f"ObjectRef {bind_name};\n"

        # match data_name to pat, and fill the var accordingly.
        # go to fail_label or ok_label base on failure/success.
        def visit_pattern(pat, data_name, fail_label, ok_label):
            if isinstance(pat, relay.PatternConstructor):
                data_name = f"Downcast({data_name})"
                ok_case = ""
                bind_names = []
                assert len(pat.constructor.inputs) == len(pat.patterns)
                for i, input_type in enumerate(pat.constructor.inputs):
                    bind_name = self.fresh_local_name()
                    bind_names.append(bind_name)
                    ok_case += f"ObjectRef {bind_name} = {data_name}->fields[{i}];\n"
                for bind_name, p in zip(bind_names, pat.patterns):
                    next_label = self.fresh_label_name()
                    ok_case += visit_pattern(p, bind_name, fail_label, next_label)
                    ok_case += f"{next_label}:\n"
                ok_case += f"goto {ok_label};"
                return f"""
                CHECK({data_name}->tag != -1);
                if ({data_name}->tag == {pat.constructor.tag}) {{
                  {ok_case}
                }} else {{
                  goto {fail_label};
                }}
                """
            elif isinstance(pat, relay.PatternVar):
                return f"""
                {self.name_map[pat.var]} = {data_name};
                """
            else:
                raise Exception(str(pat))

        in_name = self.fresh_local_name()
        out_name = self.fresh_local_name()
        stmt_str += f"ObjectRef {in_name} = {vd.expr};\n"
        stmt_str += f"ObjectRef {out_name};\n"
        match_finish_label = self.fresh_label_name()
        for c in node.clause:
            vc = self.visit(c[1])
            fail_label = self.fresh_label_name()
            ok_label = self.fresh_label_name()
            stmt_str += f"""{{
              {visit_pattern(c[0], in_name, fail_label, ok_label)}
            }}
            """
            stmt_str += f"""{{
              {ok_label}:
              {vc.stmt}
              {out_name} = {vc.expr};
              goto {match_finish_label};
            }}
            """
            # A failed clause falls through to the next clause's test.
            stmt_str += f"{fail_label}:\n"
        stmt_str += """CHECK(false) << "does not match any";\n"""
        stmt_str += f"{match_finish_label}: ;"
        return ExprWithStmt(out_name, stmt_str)

    def visit_tuple(self, node):
        expr = []
        stmt_str = ""
        for x in node.fields:
            vx = self.visit(x)
            expr.append(vx.expr)
            stmt_str += vx.stmt
        list_name = self.fresh_local_name()
        stmt_str += f"std::vector {list_name} = {{{inter(expr)}}};"
        return ExprWithStmt(f"runtime::ADT::Tuple({list_name})", stmt_str)

    def visit_if(self, node):
        vc = self.visit(node.cond)
        vt = self.visit(node.true_branch)
        vf = self.visit(node.false_branch)
        ret_name = self.fresh_local_name()
        stmt = f"ObjectRef {ret_name};"
        stmt += f"""
        {vc.stmt}
        if (NDToBool(ObjectRefToND({vc.expr}))) {{
          {vt.stmt}
          {ret_name} = {vt.expr};
        }} else {{
          {vf.stmt}
          {ret_name} = {vf.expr};
        }}
        """
        return ExprWithStmt(ret_name, stmt)

    def visit_constant(self, const):
        """Hoist `const` into a file-scope variable that is filled in at call time."""
        if const not in self.declare_map:
            name = self.fresh_global_name()
            self.declare_map[const] = name
            self.declare += f"ObjectRef {name};\n"
            self.input_const.append((name, const.data.asnumpy()))
        return ExprWithStmt(self.declare_map[const])

    def visit_global_var(self, gv):
        if gv not in self.declare_map:
            name = self.fresh_global_name()
            # Recorded before translating the body so recursive references
            # resolve to the same name.
            self.declare_map[gv] = f"{name}"
            vgv = self.visit(self.gv_map[gv], local=False, name=name)
            assert vgv.stmt == ""
            assert vgv.expr == f"{name}"
        return ExprWithStmt(self.declare_map[gv])

    def visit_args(self, args):
        """Translate `args`; return `(comma-separated expressions, statements)`."""
        exprs = []
        stmt_str = ""
        for arg in args:
            va = self.visit(arg)
            exprs.append(va.expr)
            stmt_str += va.stmt
        return inter(exprs), stmt_str

    def visit_invoke(self, invoke):
        args_str, stmt_str = self.visit_args(invoke.args)
        func = self.visit(invoke.call)
        return ExprWithStmt(f"Apply({func.expr}, std::vector({{{args_str}}}))", stmt_str + func.stmt)

    def visit_decl(self, decl):
        source = ""
        for var, value in decl.bindings:
            local_name = self.fresh_local_name(var)
            self.name_map[var] = local_name
            # `name` lets a function value refer to itself under the same
            # identifier it is bound to here.
            vv = self.visit(value, name=local_name)
            source += vv.stmt
            source += f"""ObjectRef {local_name} = {vv.expr};"""
        vb = self.visit(decl.body)
        source += vb.stmt
        return ExprWithStmt(vb.expr, source)

    def nd_dtype(self, tt):
        """Return the name of the generated `DLDataType` constant for `tt.dtype`."""
        assert isinstance(tt, relay.ty.TensorType)
        if tt.dtype == 'int32':
            return 'dtype_i32'
        elif tt.dtype == 'int8':
            return 'dtype_i8'
        elif tt.dtype == 'float32':
            return 'dtype_f32'
        elif tt.dtype == 'bool':
            return 'dtype_u1'
        raise Exception("unknown tensor dtype: " + str(tt))

    def nd_shape(self, tt):
        return f"{{{inter([str(s) for s in tt.shape])}}}"

    def visit_packed_call(self, call):
        decl_str = ""
        args = []
        for arg in call.args:
            va = self.visit(arg)
            decl_str += va.stmt
            args.append(va.expr)
        # Flattened argument list: inputs first, then pre-allocated outputs.
        args_str = []

        def convert_input(ty, arg):
            nonlocal decl_str
            if isinstance(ty, relay.ty.TensorType):
                args_str.append(f"{arg}")
            else:
                assert isinstance(ty, relay.ty.TupleType)
                tuple_name = self.fresh_local_name()
                decl_str += f"runtime::ADT {tuple_name} = Downcast({arg});\n"
                for i, t in enumerate(ty.fields):
                    convert_input(t, f"{tuple_name}[{i}]")

        assert len(call.args_type) == len(call.args)
        for arg_type, arg_expr in zip(call.args_type, args):
            convert_input(arg_type, arg_expr)

        def convert_output(ty):
            nonlocal decl_str
            if isinstance(ty, relay.ty.TensorType):
                tensor_name = self.fresh_local_name()
                decl_str += f"NDArray {tensor_name} = NDArray::Empty({self.nd_shape(ty)}, {self.nd_dtype(ty)}, context);\n"
                args_str.append(f"{tensor_name}")
                return tensor_name
            else:
                assert isinstance(ty, relay.ty.TupleType)
                list_name = self.fresh_local_name()
                list_members = inter([convert_output(t) for t in ty.fields])
                decl_str += f"std::vector {list_name} = {{{list_members}}};"
                return f"runtime::ADT::Tuple({list_name})"

        out = convert_output(call.ret_type)
        return ExprWithStmt(out, f"""
        {decl_str}
        const PackedFunc *pf = runtime::Registry::Get("{call.name}");
        CHECK(pf);
        (*pf)({inter(args_str)});
        """)

    def visit_cpp_function(self, func, local, name):
        """Emit a `FunctionValueNode::make` lambda for `func`.

        With `local` set the lambda is returned as an expression; otherwise it
        is hoisted into a file-scope variable and that variable's name is
        returned.
        """
        vec = self.fresh_local_name()
        body = ""

        for i, param in enumerate(func.params):
            pname = self.fresh_local_name(param)
            self.name_map[param] = pname
            body += f"ObjectRef {pname} = {vec}.at({i});\n"

        # Rebind the function's own name to `self` inside the lambda so that
        # recursive references work. Without a name there is nothing to
        # rebind (previously this emitted a stray `ObjectRef None = self;`).
        if name is not None:
            body += f"ObjectRef {name} = self;\n"
        vb = self.visit(func.body)
        body = body + vb.stmt + f"""return {vb.expr};"""
        expr = f"""FunctionValueNode::make([=](const std::vector& {vec}, const ObjectRef& self) {{
            {body}
        }});
        """

        if local:
            return ExprWithStmt(expr)
        else:
            if name is None:
                name = self.fresh_global_name()
            self.declare += f"""
            static ObjectRef {name}_func() {{
                static ObjectRef ret = {expr};
                return ret;
            }}
            ObjectRef {name} = {name}_func();
            """
            return ExprWithStmt(f"{name}")

    def mk_register_api(self, name: str, func) -> str:
        """Return the declarations plus a `TVM_REGISTER_GLOBAL` entry point for `func`.

        The generated packed function expects the hoisted constants first
        (in `self.input_const` order), followed by the function's own
        arguments.
        """
        vf = self.visit(func, local=False)
        assert vf.stmt == ""
        source = self.declare

        if isinstance(func, relay.GlobalVar):
            func = self.gv_map[func]
        init = ""
        for i, (input_name, _) in enumerate(self.input_const):
            init += f"{input_name} = args[{i}];\n"
        args = inter([f"args[{i+len(self.input_const)}]" for i in range(len(func.params))])

        source += f"""
        TVM_REGISTER_GLOBAL("{name}")
        .set_body([](TVMArgs args, TVMRetValue* ret) {{
            {init}
            std::initializer_list ilist = {{{args}}};
            *ret = Apply({vf.expr}, std::vector(ilist));
        }});
        """
        return source


def inter(strs, sep=", "):
    """Join the strings in `strs` with `sep` between consecutive items."""
    return sep.join(strs)


def mk_file(body, ctx):
    """Wrap `body` in the fixed C++ prelude (includes, dtype constants, helpers).

    `ctx.device_type` and `ctx.device_id` are baked into the generated
    `context` constant.
    """
    return f"""
    #include
    #include
    #include
    #include
    #include
    #include

    using namespace tvm;
    using namespace runtime;
    using namespace relay;

    static DLDataType dtype_f32 = DLDataType {{ .code = DLDataTypeCode::kDLFloat, .bits = 32, .lanes = 1 }};
    static DLDataType dtype_u32 = DLDataType {{ .code = DLDataTypeCode::kDLUInt, .bits = 32, .lanes = 1 }};
    static DLDataType dtype_u1 = DLDataType {{ .code = DLDataTypeCode::kDLUInt, .bits = 1, .lanes = 1 }};
    static DLDataType dtype_i32 = DLDataType {{ .code = DLDataTypeCode::kDLInt, .bits = 32, .lanes = 1 }};
    static DLDataType dtype_i8 = DLDataType {{ .code = DLDataTypeCode::kDLInt, .bits = 8, .lanes = 1 }};
    static DLContext context = DLContext {{ .device_type = DLDeviceType({ctx.device_type}), .device_id = {ctx.device_id} }};

    static bool NDToBool(const NDArray& nd) {{
      DLContext cpu_ctx;
      cpu_ctx.device_type = kDLCPU;
      cpu_ctx.device_id = 0;
      NDArray cpu_array = nd.CopyTo(cpu_ctx);
      CHECK_EQ(DataType(cpu_array->dtype), DataType::Bool());
      return reinterpret_cast(cpu_array->data)[0];
    }}

    static NDArray ObjectRefToND(const ObjectRef& v) {{
      return Downcast(v);
    }}

    static ConstructorValue TagToCV(size_t tag, const tvm::Array& fields) {{
      ObjectPtr n = make_object();
      ObjectPtr con = make_object();
      con->tag = tag;
      n->tag = tag;
      n->constructor = Constructor(con);
      n->fields = fields;
      return ConstructorValue(n);
    }}

    /*! \\brief A Function value. */
    class FunctionValue;

    using function_value_t = std::function&, const ObjectRef&)>;
    struct FunctionValueNode : Object {{
      function_value_t f;

      FunctionValueNode() {{ }}

      void VisitAttrs(tvm::AttrVisitor* v) {{ }}

      TVM_DLL static FunctionValue make(const function_value_t& f);

      static constexpr const char* _type_key = "relay.FunctionValue";
      TVM_DECLARE_FINAL_OBJECT_INFO(FunctionValueNode, Object);
    }};

    class FunctionValue : public ObjectRef {{
     public:
      TVM_DEFINE_OBJECT_REF_METHODS(FunctionValue, ObjectRef, FunctionValueNode);
    }};

    FunctionValue FunctionValueNode::make(const function_value_t& f) {{
      ObjectPtr n = make_object();
      n->f = f;
      return FunctionValue(n);
    }}

    ObjectRef Apply(const ObjectRef& op, const std::vector& args) {{
      return Downcast(op)->f(args, op);
    }}

    {body}
    """


def to_source(mod, program, gv_map, ctx, name) -> tuple:
    """Generate the C++ file for `program`, registered under `name`.

    Returns
    -------
    (constants, source)
        `constants` is the list of hoisted constant values (numpy arrays) in
        the order the generated entry point expects them; `source` is the
        C++ text. `mod` is currently unused.
    """
    translator = ToSource(gv_map)
    ret = mk_file(translator.mk_register_api(name, program), ctx)
    return [value for _, value in translator.input_const], ret

# ---------------------------------------------------------------------------
# diff --git a/tests/python/relay/test_aot.py b/tests/python/relay/test_aot.py
# ---------------------------------------------------------------------------
# new file mode 100644
# index 000000000000..20fef6854bdc
# tests/python/relay/test_aot.py (new file, 287 lines in the patch)
from tvm import relay
from tvm import IRModule as Module
from tvm.relay import var, Function, op, GlobalVar, TypeVar, FuncType
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions
import numpy as np
import tvm
from tvm.runtime import aot


def compile(f, mod):
    """AoT-compile `f` (in the context of `mod`) for the llvm target, device 0."""
    tgt = tvm.target.create('llvm')
    ctx = tvm.context('llvm', 0)
    return aot.compile(f, mod, ctx=ctx, tgt=tgt)


def test_identity():
    """The identity function returns its argument unchanged."""
    mod = Module()
    x = var('x', shape=())
    func = Function([x], x)
    cfunc = compile(func, mod)
    a = tvm.nd.array(np.array(1.0, dtype='float32'))
    output = cfunc(a)
    np.testing.assert_allclose(output.asnumpy(), a.asnumpy())


def test_add():
    """A single primitive operator: x + y."""
    mod = Module()
    x = var('x', shape=())
    y = var('y', shape=())
    z = x + y
    func = Function([x, y], z)
    cfunc = compile(func, mod)
    a = tvm.nd.array(np.array(1.0, dtype='float32'))
    b = tvm.nd.array(np.array(1.0, dtype='float32'))
    c = tvm.nd.array(np.array(2.0, dtype='float32'))
    output = cfunc(a, b)
    np.testing.assert_allclose(output.asnumpy(), c.asnumpy())


def test_mult_op():
    """Two chained operators: exp(x + y)."""
    mod = Module()
    x = var('x', shape=())
    y = var('y', shape=())
    z = x + y
    zz = op.exp(z)
    func = Function([x, y], zz)
    cfunc = compile(func, mod)
    a = tvm.nd.array(np.array(1.0, dtype='float32'))
    b = tvm.nd.array(np.array(1.0, dtype='float32'))
    output = cfunc(a, b)
    np.testing.assert_allclose(output.asnumpy(), np.exp(a.asnumpy() + b.asnumpy()))


def test_double():
    """Calls to a global function defined in the module."""
    mod = Module()
    x = var('x', shape=())
    double = GlobalVar('double')
    mod[double] = Function([x], x + x)
    x = var('x', shape=())
    cfunc = compile(Function([x], double(double(x))), mod)
    a = tvm.nd.array(np.array(1.5, dtype='float32'))
    output = cfunc(a)
    np.testing.assert_allclose(output.asnumpy(), np.array(6.0, dtype='float32'))


def test_42():
    """A nullary function returning a constant."""
    mod = Module()
    func = Function([], relay.const(42))
    cfunc = compile(func, mod)
    output = cfunc()
    np.testing.assert_allclose(output.asnumpy(), np.array(42.0, dtype='float32'))


def test_add_42():
    """An operator applied to an argument and a constant."""
    mod = Module()
    x = var('x', shape=())
    func = Function([x], x + relay.const(42.0))
    cfunc = compile(func, mod)
    a = tvm.nd.array(np.array(42.0, dtype='float32'))
    output = cfunc(a)
    np.testing.assert_allclose(output.asnumpy(), np.array(84.0, dtype='float32'))


def test_int_mult_3():
    """int32 arithmetic with a constant."""
    mod = Module()
    x = var('x', dtype='int32', shape=())
    func = Function([x], x * relay.const(3))
    cfunc = compile(func, mod)
    a = tvm.nd.array(np.array(4, dtype='int32'))
    output = cfunc(a)
    np.testing.assert_allclose(output.asnumpy(), np.array(12, dtype='int32'))


def test_abs():
    """Both branches of an If expression."""
    mod = Module()
    x = var('x', shape=())
    func = Function([x], relay.If(op.less(x, relay.const(0.0)), relay.const(-1.0) * x, x))
    cfunc = compile(func, mod)
    a = tvm.nd.array(np.array(12.0, dtype='float32'))
    output = cfunc(a)
    np.testing.assert_allclose(output.asnumpy(), np.array(12.0, dtype='float32'))
    a = tvm.nd.array(np.array(-34.0, dtype='float32'))
    output = cfunc(a)
    np.testing.assert_allclose(output.asnumpy(), np.array(34.0, dtype='float32'))


def test_recur_sum_global():
    """A recursive global function: sum of 0..10."""
    mod = Module()
    x = var('x', dtype='int32', shape=())
    sum = GlobalVar('sum')
    c = relay.const(0)
    mod[sum] = Function([x],
                        relay.If(op.less(x, c), c, x + sum(x - relay.const(1))),
                        relay.TensorType(dtype='int32', shape=()))
    cfunc = compile(Function([], sum(relay.const(10))), mod)
    output = cfunc()
    np.testing.assert_allclose(output.asnumpy(), np.array(55, dtype='int32'))


def nat_to_int(n):
    """Convert a returned nat constructor value into a Python int.

    The low byte of the constructor tag is compared against 1 (successor,
    recurse on the first field) and 0 (zero).
    """
    if n.constructor.tag & 0xff == 1:
        return 1 + nat_to_int(n.fields[0])
    else:
        assert n.constructor.tag & 0xff == 0
        return 0


def int_to_nat(p, i):
    """Build the Relay nat expression for the non-negative int `i` using prelude `p`."""
    if i > 0:
        return p.s(int_to_nat(p, i - 1))
    else:
        assert i == 0
        return p.z()


def test_nat_3():
    """ADT constructors: s(s(s(z)))."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    cfunc = compile(Function([], p.s(p.s(p.s(p.z())))), mod)
    output = cfunc()
    assert nat_to_int(output) == 3


def test_nat_add():
    """Prelude `add` on nats (pattern matching): 3 + 4."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    cfunc = compile(Function([], p.add(p.s(p.s(p.s(p.z()))), p.s(p.s(p.s(p.s(p.z())))))), mod)
    output = cfunc()
    assert nat_to_int(output) == 7


def test_add_convert():
    """Passing Relay constructor calls as arguments to a compiled global."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    cfunc = compile(p.add, mod)
    output = cfunc(int_to_nat(p, 12), int_to_nat(p, 34))
    assert nat_to_int(output) == 46


def test_ref():
    """RefCreate / RefRead / RefWrite: reads 1, writes 2, reads 2, returns 1 + 2."""
    mod = Module()
    three_with_ref = relay.GlobalVar('three_with_ref')
    i = relay.Var('i')
    iv = relay.Var('iv')
    u = relay.Var('u')
    uv = relay.Var('uv')
    body = relay.add(iv, uv)
    body = relay.Let(uv, relay.RefRead(i), body)
    body = relay.Let(u, relay.RefWrite(i, relay.const(2, dtype='int32')), body)
    body = relay.Let(iv, relay.RefRead(i), body)
    body = relay.Let(i, relay.RefCreate(relay.const(1, dtype='int32')), body)
    mod[three_with_ref] = relay.Function([], body)
    cfunc = compile(three_with_ref, mod)
    output = cfunc()
    np.testing.assert_allclose(output.asnumpy(), np.array(3, dtype='int32'))


def test_tuple():
    """Tuple construction followed by TupleGetItem."""
    mod = Module()
    cfunc = compile(Function([],
                             relay.TupleGetItem(relay.Tuple([relay.const(3, dtype='int32'),
                                                             relay.const(4.0, dtype='float32')]),
                                                1)),
                    mod)
    np.testing.assert_allclose(cfunc().asnumpy(), np.array(4.0, dtype='float32'))


def test_get_valid_counts():
    # Based on test_get_valid_counts in test_op_level5.
    # Tests the case of a packed func returning a Relay tuple.
    # Only checks the shapes of the output because the reference implementation
    # is long and inconvenient.
    shape = (1, 2500, 6)
    score_threshold = 0
    id_index = 0
    score_index = 1
    np_data = np.random.uniform(low=-2, high=2, size=shape).astype("float32")
    mod = Module()
    cfunc = compile(
        relay.Function(
            [],
            relay.vision.get_valid_counts(
                relay.const(np_data), score_threshold, id_index, score_index
            ).astuple()),
        mod)

    relay_out = cfunc()
    out1 = relay_out[0].asnumpy()
    out2 = relay_out[1].asnumpy()
    assert out1.shape == (shape[0],)
    assert out2.shape == shape


def test_compose():
    """Higher-order prelude `compose` applied to two global functions."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    x = relay.Var('x')
    inc = GlobalVar('inc')
    mod[inc] = Function([x], p.s(x))
    x = relay.Var('x')
    func = GlobalVar('func')
    f = Function([x], relay.Call(p.compose(inc, p.double), [x]))
    mod[func] = f
    cfunc = compile(func, mod)
    assert nat_to_int(cfunc(p.s(p.s(p.z())))) == 5


def test_recur_sum_local():
    """A let-bound function that calls itself."""
    mod = Module()
    x = var('x', dtype='int32', shape=())
    t = relay.TensorType(dtype='int32', shape=())
    sum = relay.Var('sum', type_annotation=relay.FuncType([t], t))
    c = relay.const(0)
    func = Function([x],
                    relay.If(op.less(x, c), c, x + sum(x - relay.const(1))),
                    t)
    body = relay.Let(sum, func, sum(relay.const(10)))
    cfunc = compile(Function([], body), mod)
    output = cfunc()
    np.testing.assert_allclose(output.asnumpy(), np.array(55, dtype='int32'))


def test_local_local_rec_outer_scope():
    mod = Module()
    x = var('x', dtype='int32', shape=())
    t = relay.TensorType(dtype='int32', shape=())
    sum = relay.Var('sum', type_annotation=relay.FuncType([t], t))
    c = relay.const(0)

    # we define a locally recursive function inside another function's scope
    # and have that function return the closure of the locally recursive function
    inner_func = Function([x],
                          relay.If(op.less(x, c), c, x + sum(x - relay.const(1))),
                          t)
    outer_func_body = relay.Let(sum, inner_func, sum)
    outer_func = Function([], outer_func_body)
    f = relay.Var('f')
    body = relay.Let(f, outer_func(), f(relay.const(10)))
    cfunc = compile(Function([], body), mod)
    output = cfunc()
    np.testing.assert_allclose(output.asnumpy(), np.array(55, dtype='int32'))


if __name__ == "__main__":
    test_identity()
    test_add()
    test_mult_op()
    test_double()
    test_42()
    test_add_42()
    test_int_mult_3()
    test_abs()
    test_recur_sum_global()
    test_nat_3()
    test_nat_add()
    test_add_convert()
    test_ref()
    test_tuple()
    test_get_valid_counts()
    test_compose()
    test_recur_sum_local()
    test_local_local_rec_outer_scope()

# ---------------------------------------------------------------------------
# Next patch in the series (its metadata continues on the following line).
#
# From fd3579f76b8616a87baac0452ddd11bfeced8696 Mon Sep 17 00:00:00 2001
# From: "Steven S. Lyubomirsky"
# Date: Thu, 6 Aug 2020 16:33:58 -0700
# Subject: [PATCH 2/4] Move AOT files to relay/backend
#
# ---
#  python/tvm/{runtime => relay/backend}/aot/__init__.py   | 0
#  python/tvm/{runtime => relay/backend}/aot/aot.py        | 0
#  python/tvm/{runtime => relay/backend}/aot/convert.py    | 0
#  python/tvm/{runtime => relay/backend}/aot/little_cpp.py | 0
#  python/tvm/{runtime => relay/backend}/aot/to_source.py  | 0
#  tests/python/relay/test_aot.py                          | 7 ++++---
#  6 files changed, 4 insertions(+), 3 deletions(-)
#  rename python/tvm/{runtime => relay/backend}/aot/__init__.py (100%)
#  rename python/tvm/{runtime => relay/backend}/aot/aot.py (100%)
#  rename python/tvm/{runtime => relay/backend}/aot/convert.py (100%)
#  rename python/tvm/{runtime => relay/backend}/aot/little_cpp.py (100%)
#  rename python/tvm/{runtime => relay/backend}/aot/to_source.py (100%)
#
# diff --git a/python/tvm/runtime/aot/__init__.py b/python/tvm/relay/backend/aot/__init__.py
# similarity index 100%
# rename from python/tvm/runtime/aot/__init__.py
# rename to python/tvm/relay/backend/aot/__init__.py
# diff --git a/python/tvm/runtime/aot/aot.py b/python/tvm/relay/backend/aot/aot.py
# similarity index 100%
# rename from python/tvm/runtime/aot/aot.py
# rename to python/tvm/relay/backend/aot/aot.py
# diff --git a/python/tvm/runtime/aot/convert.py b/python/tvm/relay/backend/aot/convert.py
# similarity index 100%
# rename from python/tvm/runtime/aot/convert.py
rename to python/tvm/relay/backend/aot/convert.py diff --git a/python/tvm/runtime/aot/little_cpp.py b/python/tvm/relay/backend/aot/little_cpp.py similarity index 100% rename from python/tvm/runtime/aot/little_cpp.py rename to python/tvm/relay/backend/aot/little_cpp.py diff --git a/python/tvm/runtime/aot/to_source.py b/python/tvm/relay/backend/aot/to_source.py similarity index 100% rename from python/tvm/runtime/aot/to_source.py rename to python/tvm/relay/backend/aot/to_source.py diff --git a/tests/python/relay/test_aot.py b/tests/python/relay/test_aot.py index 20fef6854bdc..abbceb62b803 100644 --- a/tests/python/relay/test_aot.py +++ b/tests/python/relay/test_aot.py @@ -1,11 +1,12 @@ +import numpy as np + +import tvm from tvm import relay from tvm import IRModule as Module from tvm.relay import var, Function, op, GlobalVar, TypeVar, FuncType from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions -import numpy as np -import tvm -from tvm.runtime import aot +from tvm.relay.backend import aot def compile(f, mod): From 8961bcb091b6c5d2727455db3bb74dad974c253f Mon Sep 17 00:00:00 2001 From: "Steven S. 
Lyubomirsky" Date: Thu, 6 Aug 2020 16:56:46 -0700 Subject: [PATCH 3/4] Add ASF headers to AoT files --- python/tvm/relay/backend/aot/__init__.py | 17 +++++++++++++++++ python/tvm/relay/backend/aot/aot.py | 17 +++++++++++++++++ python/tvm/relay/backend/aot/convert.py | 17 +++++++++++++++++ python/tvm/relay/backend/aot/little_cpp.py | 17 +++++++++++++++++ python/tvm/relay/backend/aot/to_source.py | 17 +++++++++++++++++ tests/python/relay/test_aot.py | 17 +++++++++++++++++ 6 files changed, 102 insertions(+) diff --git a/python/tvm/relay/backend/aot/__init__.py b/python/tvm/relay/backend/aot/__init__.py index b02a1c672b81..8a125be14c8f 100644 --- a/python/tvm/relay/backend/aot/__init__.py +++ b/python/tvm/relay/backend/aot/__init__.py @@ -1 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from .aot import compile diff --git a/python/tvm/relay/backend/aot/aot.py b/python/tvm/relay/backend/aot/aot.py index d1a9ff2cc963..6ca433008b61 100644 --- a/python/tvm/relay/backend/aot/aot.py +++ b/python/tvm/relay/backend/aot/aot.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import ctypes import os import subprocess diff --git a/python/tvm/relay/backend/aot/convert.py b/python/tvm/relay/backend/aot/convert.py index 5daaa83b66a8..dc51e16b0783 100644 --- a/python/tvm/relay/backend/aot/convert.py +++ b/python/tvm/relay/backend/aot/convert.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ import numpy as np import tvm from tvm import relay diff --git a/python/tvm/relay/backend/aot/little_cpp.py b/python/tvm/relay/backend/aot/little_cpp.py index f3b819d95672..8348e9c92dfa 100644 --- a/python/tvm/relay/backend/aot/little_cpp.py +++ b/python/tvm/relay/backend/aot/little_cpp.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from tvm.relay import Var, TypeVar from typing import Any, Optional, List, Tuple import attr diff --git a/python/tvm/relay/backend/aot/to_source.py b/python/tvm/relay/backend/aot/to_source.py index a5aa373d9b4a..96f6bc62b2db 100644 --- a/python/tvm/relay/backend/aot/to_source.py +++ b/python/tvm/relay/backend/aot/to_source.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from . import little_cpp from tvm import relay from tvm.relay.prelude import Prelude diff --git a/tests/python/relay/test_aot.py b/tests/python/relay/test_aot.py index abbceb62b803..352059039ff9 100644 --- a/tests/python/relay/test_aot.py +++ b/tests/python/relay/test_aot.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import numpy as np import tvm From 44358e297b3f2007a3262e293d83c73c8b09c73b Mon Sep 17 00:00:00 2001 From: "Steven S. 
Lyubomirsky" Date: Thu, 6 Aug 2020 18:45:15 -0700 Subject: [PATCH 4/4] Fix tons of pylint errors --- python/tvm/relay/backend/aot/__init__.py | 8 +- python/tvm/relay/backend/aot/aot.py | 102 +++---- python/tvm/relay/backend/aot/convert.py | 17 +- python/tvm/relay/backend/aot/little_cpp.py | 8 +- python/tvm/relay/backend/aot/to_source.py | 296 +++++++++++++-------- tests/python/relay/test_aot.py | 2 +- 6 files changed, 267 insertions(+), 166 deletions(-) diff --git a/python/tvm/relay/backend/aot/__init__.py b/python/tvm/relay/backend/aot/__init__.py index 8a125be14c8f..888b331ae092 100644 --- a/python/tvm/relay/backend/aot/__init__.py +++ b/python/tvm/relay/backend/aot/__init__.py @@ -15,4 +15,10 @@ # specific language governing permissions and limitations # under the License. -from .aot import compile +""" +This module defines the Relay ahead-of-time (AoT) compiler, +which translates Relay ASTs into C++ code that calls into +already-compiled operators. These end-to-end compiled +programs can in principle run without a runtime. +""" +from .aot import compile_prog diff --git a/python/tvm/relay/backend/aot/aot.py b/python/tvm/relay/backend/aot/aot.py index 6ca433008b61..f66e1c49f797 100644 --- a/python/tvm/relay/backend/aot/aot.py +++ b/python/tvm/relay/backend/aot/aot.py @@ -14,30 +14,39 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +""" +Defines the entry point into the AoT compiler. 
+""" import ctypes import os import subprocess import tempfile +import time + import tvm -from tvm import relay, get_global_func, target, register_func +from tvm import relay, get_global_func, register_func from tvm.relay.function import Function from tvm.relay.expr import Expr, Let, GlobalVar from tvm.relay.adt import Constructor -from tvm.relay.expr_functor import ExprFunctor, ExprVisitor +from tvm.relay.expr_functor import ExprFunctor from tvm.relay.backend import compile_engine -from .little_cpp import PackedCall, CPPFunction, Invoke, Decl, CPPIf, CPPTuple, CPPMatch, CPPConstructor, CPPTupleGetItem -from .little_cpp import CPPRefCreate, CPPRefRead, CPPRefWrite +from .little_cpp import (PackedCall, CPPFunction, Invoke, Decl, CPPIf, + CPPTuple, CPPMatch, CPPConstructor, CPPTupleGetItem, + CPPRefCreate, CPPRefRead, CPPRefWrite) from . import to_source from .convert import convert TVM_PATH = os.environ['TVM_HOME'] def must_run_process(args): - proc = subprocess.run(args) + proc = subprocess.run(args, check=True) assert proc.returncode == 0 def compile_cpp(source, lib_name, flags=None, lib_path=None): + """ + Compiles the given source into a C++ library + and returns the full path to the compiled library. 
+ """ if flags is None: flags = [] @@ -59,37 +68,27 @@ def compile_cpp(source, lib_name, flags=None, lib_path=None): must_run_process(["clang-format", "-i", debug_source_path]) system = os.uname()[0] + include_paths = [ + f"-I{TVM_PATH}/3rdparty/dmlc-core/include", + f"-I{TVM_PATH}/3rdparty/dlpack/include", + f"-I{TVM_PATH}/3rdparty/HalideIR/src", + f"-I{TVM_PATH}/include", + f"-L{TVM_PATH}/build" + ] + if system == 'Darwin': command = [ - "clang", - "-std=c++14", - "-shared", - "-undefined", - "dynamic_lookup", - "-o", - lib_path, + "clang", "-std=c++14", "-shared", "-undefined", "dynamic_lookup", + "-o", lib_path, source_path, - f"-I{TVM_PATH}/3rdparty/dmlc-core/include", - f"-I{TVM_PATH}/3rdparty/dlpack/include", - f"-I{TVM_PATH}/3rdparty/HalideIR/src", - f"-I{TVM_PATH}/include", - f"-L{TVM_PATH}/build", + *include_paths, "-ltvm" ] + flags else: command = [ - "clang", - "-std=c++14", - "-shared", - "-fPIC", - "-o", - lib_path, + "clang", "-std=c++14", "-shared", "-fPIC", "-o", lib_path, source_path, - f"-I{TVM_PATH}/3rdparty/dmlc-core/include", - f"-I{TVM_PATH}/3rdparty/dlpack/include", - f"-I{TVM_PATH}/3rdparty/HalideIR/src", - f"-I{TVM_PATH}/include", - f"-L{TVM_PATH}/build", + *include_paths, "-ltvm" ] + flags @@ -99,10 +98,16 @@ def compile_cpp(source, lib_name, flags=None, lib_path=None): def load_lib(name): return ctypes.CDLL(name, ctypes.RTLD_GLOBAL) -def is_primitive(e: relay.Expr): - return isinstance(e, relay.Function) and e.attrs and e.attrs.Primitive.value == 1 +def is_primitive(expr: relay.Expr): + return (isinstance(expr, relay.Function) + and expr.attrs + and expr.attrs.Primitive.value == 1) class AoTCompiler(ExprFunctor): + """ + Takes a Relay program and converts it into a Little CPP program + that can in turn be converted into C++ source code. 
+ """ def __init__(self, mod, tgt) -> None: super().__init__() self.mod = mod @@ -126,8 +131,8 @@ def optimize(self, expr: Function) -> Function: def mk_primitive_op(self, func: Expr, args, output_type) -> Expr: cc_key = compile_engine.CCacheKey(func, self.tgt) - hash = tvm.ir.structural_hash(func) - name = f"op_{hash}" + func_hash = tvm.ir.structural_hash(func) + name = f"op_{func_hash}" if not get_global_func(name, allow_missing=True): jit_func = self.engine.jit(cc_key, self.tgt) register_func(name, jit_func) @@ -136,13 +141,12 @@ def mk_primitive_op(self, func: Expr, args, output_type) -> Expr: def visit_call(self, call: Expr) -> Expr: if is_primitive(call.op): return self.mk_primitive_op(call.op, call.args, call.checked_type) - elif isinstance(call.op, Constructor): + if isinstance(call.op, Constructor): return CPPConstructor(call.op.tag, [self.visit(arg) for arg in call.args]) - else: - assert(call.attrs == None) - args = [self.visit(arg) for arg in call.args] - fn = self.visit(call.op) - return Invoke(fn, args) + assert call.attrs is None + args = [self.visit(arg) for arg in call.args] + func = self.visit(call.op) + return Invoke(func, args) def visit_let(self, let: Expr) -> Expr: self.bindings.append([]) @@ -170,8 +174,7 @@ def visit_function(self, func): if is_primitive(func): body = self.mk_primitive_op(func, func.params, func.ret_type) return CPPFunction(func.params, body, func.checked_type.ret_type) - else: - return CPPFunction(func.params, self.visit(func.body), func.checked_type.ret_type) + return CPPFunction(func.params, self.visit(func.body), func.checked_type.ret_type) def visit_constant(self, const): return const @@ -193,6 +196,9 @@ def visit_match(self, m): def visit_op(self, op): raise Exception(f'op outside of primitive: {op}') + def visit_constructor(self, ctor): + raise Exception('Constructors should be handled when visiting calls.') + def visit_tuple_getitem(self, t): return CPPTupleGetItem(self.visit(t.tuple_value), t.index, t.checked_type) 
@@ -215,19 +221,17 @@ def lib_and_func_name(name): _LIB_COUNTER += 1 return lib_name, packed_name -import time - -def _mk_wrapper(fn, ctx, constants, record_time): +def _mk_wrapper(func, ctx, constants, record_time): def _wrapper(*args): new_constants = [convert(a, ctx) for a in constants] new_args = [convert(a, ctx) for a in args] begin = time.perf_counter() - res = fn(*new_constants, *new_args) + res = func(*new_constants, *new_args) end = time.perf_counter() return res if not record_time else (res, end - begin) return _wrapper -def compile(func, mod, ctx, tgt, name='default', record_time=False): +def compile_prog(func, mod, ctx, tgt, name='default', record_time=False): """Compile a Relay function into a C++ file that implements a program with the same semantics, which calls into TVM only for operators. @@ -270,9 +274,9 @@ def compile(func, mod, ctx, tgt, name='default', record_time=False): func = compiler.optimize(func) func = compiler.visit(func) lib_name, packed_name = lib_and_func_name(name) - constants, source_code = to_source.to_source(mod, func, compiler.gv_map, ctx, packed_name) + constants, source_code = to_source.to_source(func, compiler.gv_map, ctx, packed_name) lib_name = f"librelay_aot_{_LIB_COUNTER}.so" library_path = compile_cpp(source_code, lib_name, flags=["-O3"]) _LIB.append(load_lib(library_path)) - fn = get_global_func(packed_name) - return _mk_wrapper(fn, ctx, constants, record_time) + func = get_global_func(packed_name) + return _mk_wrapper(func, ctx, constants, record_time) diff --git a/python/tvm/relay/backend/aot/convert.py b/python/tvm/relay/backend/aot/convert.py index dc51e16b0783..be4da397bb6b 100644 --- a/python/tvm/relay/backend/aot/convert.py +++ b/python/tvm/relay/backend/aot/convert.py @@ -14,13 +14,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- +""" +Responsible for converting function arguments into +a form that can be passed to a `PackedFunc`. +""" import numpy as np import tvm from tvm import relay -# convert(convert(a, tg), tg) = convert(a, tg) def convert(a, ctx): + """ + Converts a function input `a` + (which may be a constant defined in Relay, a numpy array, + or a TVM NDArray) + into a form that can be passed to a TVM `PackedFunc` + with the given context. + """ + # convert(convert(a, tg), tg) = convert(a, tg) while True: if isinstance(a, int): a = np.array(a, dtype='int32') @@ -33,7 +43,8 @@ def convert(a, ctx): a = (a.op, *a.args) elif isinstance(a, tuple): assert isinstance(a[0], relay.Constructor) - a = relay.backend.interpreter.ConstructorValue(a[0].tag, [convert(arg, ctx) for arg in a[1:]], a[0]) + a = relay.backend.interpreter.ConstructorValue( + a[0].tag, [convert(arg, ctx) for arg in a[1:]], a[0]) elif isinstance(a, relay.backend.interpreter.ConstructorValue): return a else: diff --git a/python/tvm/relay/backend/aot/little_cpp.py b/python/tvm/relay/backend/aot/little_cpp.py index 8348e9c92dfa..d72fbd8e4d3e 100644 --- a/python/tvm/relay/backend/aot/little_cpp.py +++ b/python/tvm/relay/backend/aot/little_cpp.py @@ -14,10 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -from tvm.relay import Var, TypeVar +""" +Defines the Little CPP intermediate representation used by the AoT compiler +(corresponds to a small subset of the C++ AST +and a couple of TVM-specific concepts). +""" from typing import Any, Optional, List, Tuple import attr +from tvm.relay import Var class LittleCppNode: pass diff --git a/python/tvm/relay/backend/aot/to_source.py b/python/tvm/relay/backend/aot/to_source.py index 96f6bc62b2db..3eb2b1e04193 100644 --- a/python/tvm/relay/backend/aot/to_source.py +++ b/python/tvm/relay/backend/aot/to_source.py @@ -14,12 +14,18 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - -from . import little_cpp +""" +Responsible for taking Little CPP ASTs and converting them +into C++ source code. +""" from tvm import relay -from tvm.relay.prelude import Prelude +from . import little_cpp class ExprWithStmt: + """ + Representation of an expression that requires a + C++ statement to define terms used in it. + """ def __init__(self, expr, stmt=""): assert isinstance(expr, str) assert isinstance(stmt, str) @@ -35,6 +41,9 @@ def __repr__(self): return self.__str__() class ToSource: + """ + Handles converting Little CPP ASTs into C++ source code. + """ def __init__(self, gv_map): self.gv_map = gv_map self.name_counter = 0 @@ -50,8 +59,8 @@ def fresh_global_name(self): self.name_counter += 1 return name - def sanitize(self, str): - return str.replace("-", "_").replace("/", "_") + def sanitize(self, name): + return name.replace("-", "_").replace("/", "_") def fresh_local_name(self, var=None): if var is not None: @@ -66,8 +75,32 @@ def fresh_label_name(self): self.name_counter += 1 return name - # return (str, str) with lhs being stmts, and rhs being expression def visit(self, node, *, local=True, name=None): + """ + Visits a Little CPP node and returns C++ code as text + in the form of statements with the necessary definitions + and an expression corresponding to the result of the node. + + Parameters + ---------- + node: Little CPP node to be compiled + + local: Optional[bool] + For function definitions, specifies whether to + treat it as a local definition (if True) or global (if False) + + name: Optional[str] + Specifies a name for the node if it's a function + definition (otherwise the compiler generates a name) + + Returns + ------- + result: ExprWithStmt + Contains a C++ expression corresponding + to the result of `node` (as a string) + with one or more C++ statements (as strings) + containing needed definitions. 
+ """ if isinstance(node, little_cpp.PackedCall): res = self.visit_packed_call(node) elif isinstance(node, little_cpp.CPPFunction): @@ -104,28 +137,36 @@ def visit(self, node, *, local=True, name=None): return res def visit_ref_create(self, node): - vv = self.visit(node.value) - return ExprWithStmt(f"RefValue({vv.expr})", vv.stmt) + value = self.visit(node.value) + return ExprWithStmt(f"RefValue({value.expr})", value.stmt) def visit_ref_read(self, node): - vr = self.visit(node.ref) - return ExprWithStmt(f"Downcast({vr.expr})->value", vr.stmt) + ref = self.visit(node.ref) + return ExprWithStmt(f"Downcast({ref.expr})->value", ref.stmt) def visit_ref_write(self, node): - vr = self.visit(node.ref) - vv = self.visit(node.value) - stmt = vr.stmt + vv.stmt + f"Downcast({vr.expr})->value={vv.expr};\n" + ref = self.visit(node.ref) + value = self.visit(node.value) + stmt = ref.stmt + value.stmt + f"Downcast({ref.expr})->value={value.expr};\n" return ExprWithStmt("runtime::ADT::Tuple()", stmt) def visit_tuple_getitem(self, node): - vt = self.visit(node.tuple_value) - return ExprWithStmt(f"Downcast({vt.expr})[{node.index}]", vt.stmt) + visit_tup = self.visit(node.tuple_value) + return ExprWithStmt(f"Downcast" + f"({visit_tup.expr})" + f"[{node.index}]", visit_tup.stmt) def visit_constructor(self, node): - args_str, stmt_str = self.visit_args(node.fields) + args_str, _ = self.visit_args(node.fields) return ExprWithStmt(f"TagToCV({node.tag}, {{{args_str}}})") def pattern_var(self, pat, var_set): + """ + Given a match pattern `pat` and a set of variable names `var_set`, + adds the variables appearing in `pat` to `var_set` and + raises an exception if any is already in the set + (the names should be distinct). 
+ """ if isinstance(pat, relay.PatternConstructor): for x in pat.patterns: self.pattern_var(x, var_set) @@ -136,8 +177,11 @@ def pattern_var(self, pat, var_set): raise Exception(str(pat)) def visit_match(self, node): - vd = self.visit(node.data) - stmt_str = vd.stmt + """ + Handle a match expression. + """ + data = self.visit(node.data) + stmt_str = data.stmt pattern_var_set = set() for c in node.clause: @@ -156,13 +200,13 @@ def visit_pattern(pat, data_name, fail_label, ok_label): ok_case = "" bind_names = [] assert len(pat.constructor.inputs) == len(pat.patterns) - for i, input_type in enumerate(pat.constructor.inputs): + for i, _ in enumerate(pat.constructor.inputs): bind_name = self.fresh_local_name() bind_names.append(bind_name) ok_case += f"ObjectRef {bind_name} = {data_name}->fields[{i}];\n" - for bind_name, p in zip(bind_names, pat.patterns): + for bind_name, pattern in zip(bind_names, pat.patterns): next_label = self.fresh_label_name() - ok_case += visit_pattern(p, bind_name, fail_label, next_label) + ok_case += visit_pattern(pattern, bind_name, fail_label, next_label) ok_case += f"{next_label}:\n" ok_case += f"goto {ok_label};" return f""" @@ -173,30 +217,31 @@ def visit_pattern(pat, data_name, fail_label, ok_label): goto {fail_label}; }} """ - elif isinstance(pat, relay.PatternVar): + + if isinstance(pat, relay.PatternVar): return f""" {self.name_map[pat.var]} = {data_name}; """ - else: - raise Exception(str(pat)) + + raise Exception(str(pat)) in_name = self.fresh_local_name() out_name = self.fresh_local_name() - stmt_str += f"ObjectRef {in_name} = {vd.expr};\n" + stmt_str += f"ObjectRef {in_name} = {data.expr};\n" stmt_str += f"ObjectRef {out_name};\n" match_finish_label = self.fresh_label_name() - for c in node.clause: - vc = self.visit(c[1]) + for clause in node.clause: + clause_value = self.visit(clause[1]) fail_label = self.fresh_label_name() ok_label = self.fresh_label_name() stmt_str += f"""{{ - {visit_pattern(c[0], in_name, fail_label, ok_label)} 
+ {visit_pattern(clause[0], in_name, fail_label, ok_label)} }} """ stmt_str += f"""{{ {ok_label}: - {vc.stmt} - {out_name} = {vc.expr}; + {clause_value.stmt} + {out_name} = {clause_value.expr}; goto {match_finish_label}; }} """ @@ -208,28 +253,31 @@ def visit_pattern(pat, data_name, fail_label, ok_label): def visit_tuple(self, node): expr = [] stmt_str = "" - for x in node.fields: - vx = self.visit(x) - expr.append(vx.expr) - stmt_str += vx.stmt + for field in node.fields: + visit_field = self.visit(field) + expr.append(visit_field.expr) + stmt_str += visit_field.stmt list_name = self.fresh_local_name() stmt_str += f"std::vector {list_name} = {{{inter(expr)}}};" return ExprWithStmt(f"runtime::ADT::Tuple({list_name})", stmt_str) def visit_if(self, node): - vc = self.visit(node.cond) - vt = self.visit(node.true_branch) - vf = self.visit(node.false_branch) + """ + Handle an if-else expression. + """ + cond = self.visit(node.cond) + true_branch = self.visit(node.true_branch) + false_branch = self.visit(node.false_branch) ret_name = self.fresh_local_name() stmt = f"ObjectRef {ret_name};" stmt += f""" - {vc.stmt} - if (NDToBool(ObjectRefToND({vc.expr}))) {{ - {vt.stmt} - {ret_name} = {vt.expr}; + {cond.stmt} + if (NDToBool(ObjectRefToND({cond.expr}))) {{ + {true_branch.stmt} + {ret_name} = {true_branch.expr}; }} else {{ - {vf.stmt} - {ret_name} = {vf.expr}; + {false_branch.stmt} + {ret_name} = {false_branch.expr}; }} """ return ExprWithStmt(ret_name, stmt) @@ -242,22 +290,23 @@ def visit_constant(self, const): self.input_const.append((name, const.data.asnumpy())) return ExprWithStmt(self.declare_map[const]) - def visit_global_var(self, gv): - if gv not in self.declare_map: + def visit_global_var(self, global_var): + if global_var not in self.declare_map: name = self.fresh_global_name() - self.declare_map[gv] = f"{name}" - vgv = self.visit(self.gv_map[gv], local=False, name=name) - assert vgv.stmt == "" - assert vgv.expr == f"{name}" - return 
ExprWithStmt(self.declare_map[gv]) + self.declare_map[global_var] = f"{name}" + visit_gv = self.visit(self.gv_map[global_var], + local=False, name=name) + assert visit_gv.stmt == "" + assert visit_gv.expr == f"{name}" + return ExprWithStmt(self.declare_map[global_var]) def visit_args(self, args): args_str = "" stmt_str = "" for i, arg in enumerate(args): - va = self.visit(arg) - args_str += va.expr - stmt_str += va.stmt + visit_arg = self.visit(arg) + args_str += visit_arg.expr + stmt_str += visit_arg.stmt if i != len(args) - 1: args_str += ", " return args_str, stmt_str @@ -265,70 +314,87 @@ def visit_args(self, args): def visit_invoke(self, invoke): args_str, stmt_str = self.visit_args(invoke.args) func = self.visit(invoke.call) - return ExprWithStmt(f"Apply({func.expr}, std::vector({{{args_str}}}))", stmt_str + func.stmt) + return ExprWithStmt( + f"Apply({func.expr}, std::vector({{{args_str}}}))", + stmt_str + func.stmt) def visit_decl(self, decl): + """ + Handles a declaration. + """ source = "" for var, value in decl.bindings: local_name = self.fresh_local_name(var) self.name_map[var] = local_name - vv = self.visit(value, name=local_name) - source += vv.stmt - source += f"""ObjectRef {local_name} = {vv.expr};""" - vb = self.visit(decl.body) - source += vb.stmt - return ExprWithStmt(vb.expr, source) - - def nd_dtype(self, tt): - assert isinstance(tt, relay.ty.TensorType) - if tt.dtype == 'int32': + visited_value = self.visit(value, name=local_name) + source += visited_value.stmt + source += f"""ObjectRef {local_name} = {visited_value.expr};""" + body = self.visit(decl.body) + source += body.stmt + return ExprWithStmt(body.expr, source) + + def nd_dtype(self, tensor_type): + """Given a Relay tensor type, returns the appropriate dtype name""" + assert isinstance(tensor_type, relay.ty.TensorType) + if tensor_type.dtype == 'int32': return 'dtype_i32' - elif tt.dtype == 'int8': + if tensor_type.dtype == 'int8': return 'dtype_i8' - elif tt.dtype == 'float32': + if 
tensor_type.dtype == 'float32': return 'dtype_f32' - elif tt.dtype == 'bool': + if tensor_type.dtype == 'bool': return 'dtype_u1' - raise Exception("unknown tensor dtype: " + str(tt)) + raise Exception("unknown tensor dtype: " + str(tensor_type)) - def nd_shape(self, tt): - return f"{{{inter([str(s) for s in tt.shape])}}}" + def nd_shape(self, tensor_type): + """ + Given a Relay tensor type, returns its shape. + """ + return f"{{{inter([str(s) for s in tensor_type.shape])}}}" def visit_packed_call(self, call): + """ + Handle a call to a PackedFunc. + """ decl_str = "" args = [] for arg in call.args: - va = self.visit(arg) - decl_str += va.stmt - args.append(va.expr) + visit_arg = self.visit(arg) + decl_str += visit_arg.stmt + args.append(visit_arg.expr) args_str = [] - def convert_input(ty, arg): - if isinstance(ty, relay.ty.TensorType): + + def convert_input(input_ty, arg): + if isinstance(input_ty, relay.ty.TensorType): args_str.append(f"{arg}") else: - assert isinstance(ty, relay.ty.TupleType) + assert isinstance(input_ty, relay.ty.TupleType) tuple_name = self.fresh_local_name() nonlocal decl_str - decl_str += f"runtime::ADT {tuple_name} = Downcast({arg});\n" - for i, t in enumerate(ty.fields): + decl_str += (f"runtime::ADT {tuple_name} =" + f" Downcast({arg});\n") + for i, t in enumerate(input_ty.fields): convert_input(t, f"{tuple_name}[{i}]") assert len(call.args_type) == len(call.args) for i in range(len(call.args_type)): convert_input(call.args_type[i], args[i]) - def convert_output(ty): + def convert_output(output_ty): nonlocal decl_str - if isinstance(ty, relay.ty.TensorType): + if isinstance(output_ty, relay.ty.TensorType): tensor_name = self.fresh_local_name() - decl_str += f"NDArray {tensor_name} = NDArray::Empty({self.nd_shape(ty)}, {self.nd_dtype(ty)}, context);\n" + decl_str += (f"NDArray {tensor_name} = " + f"NDArray::Empty({self.nd_shape(output_ty)}, " + f"{self.nd_dtype(output_ty)}, context);\n") args_str.append(f"{tensor_name}") return 
tensor_name - else: - assert isinstance(ty, relay.ty.TupleType) - list_name = self.fresh_local_name() - list_members = inter([convert_output(t) for t in ty.fields]) - decl_str += f"std::vector {list_name} = {{{list_members}}};" - return f"runtime::ADT::Tuple({list_name})" + + assert isinstance(output_ty, relay.ty.TupleType) + list_name = self.fresh_local_name() + list_members = inter([convert_output(t) for t in output_ty.fields]) + decl_str += f"std::vector {list_name} = {{{list_members}}};" + return f"runtime::ADT::Tuple({list_name})" + out = convert_output(call.ret_type) return ExprWithStmt(out, f""" {decl_str} @@ -338,40 +404,50 @@ def convert_output(ty): """) def visit_cpp_function(self, func, local, name): + """ + Handle a Little CPP function. + """ vec = self.fresh_local_name() body = "" - end = len(func.params) - 1 for i, param in enumerate(func.params): pname = self.fresh_local_name(param) self.name_map[param] = pname body += f"ObjectRef {pname} = {vec}.at({i});\n" body += f"ObjectRef {name} = self;\n" - vb = self.visit(func.body) - body = body + vb.stmt + f"""return {vb.expr};""" - expr = f"""FunctionValueNode::make([=](const std::vector& {vec}, const ObjectRef& self) {{ - {body} - }}); + visit_body = self.visit(func.body) + body = body + visit_body.stmt + f"""return {visit_body.expr};""" + expr = f""" + FunctionValueNode::make([=]( + const std::vector& {vec}, + const ObjectRef& self) {{ + {body} + }}); """ if local: return ExprWithStmt(expr) - else: - if name is None: - name = self.fresh_global_name() - self.declare += f""" - static ObjectRef {name}_func() {{ - static ObjectRef ret = {expr}; - return ret; - }} - ObjectRef {name} = {name}_func(); - """ - return ExprWithStmt(f"{name}") + + if name is None: + name = self.fresh_global_name() + + self.declare += f""" + static ObjectRef {name}_func() {{ + static ObjectRef ret = {expr}; + return ret; + }} + ObjectRef {name} = {name}_func(); + """ + return ExprWithStmt(f"{name}") def mk_register_api(self, name: 
str, func) -> str: - vf = self.visit(func, local=False) - assert vf.stmt == "" + """ + Converts the given Little CPP function into C++ text + and registers the produced C++ function as a TVM `PackedFunc` under the given name. + """ + visited_func = self.visit(func, local=False) + assert visited_func.stmt == "" source = self.declare args = "" @@ -391,15 +467,15 @@ def mk_register_api(self, name: str, func) -> str: .set_body([](TVMArgs args, TVMRetValue* ret) {{ {init} std::initializer_list ilist = {{{args}}}; - *ret = Apply({vf.expr}, std::vector(ilist)); + *ret = Apply({visited_func.expr}, std::vector(ilist)); }}); """ return source def inter(strs, sep=", "): ret = "" - for i in range(len(strs)): - ret += strs[i] + for i, string in enumerate(strs): + ret += string if i != len(strs) - 1: ret += sep return ret @@ -482,7 +558,7 @@ class FunctionValue : public ObjectRef {{ {body} """ -def to_source(mod, program, gv_map, ctx, name) -> str: +def to_source(program, gv_map, ctx, name) -> str: convert = ToSource(gv_map) ret = mk_file(convert.mk_register_api(name, program), ctx) return [value for name, value in convert.input_const], ret diff --git a/tests/python/relay/test_aot.py b/tests/python/relay/test_aot.py index 352059039ff9..e8021a52f0ce 100644 --- a/tests/python/relay/test_aot.py +++ b/tests/python/relay/test_aot.py @@ -29,7 +29,7 @@ def compile(f, mod): tgt = tvm.target.create('llvm') ctx = tvm.context('llvm', 0) - return aot.compile(f, mod, ctx=ctx, tgt=tgt) + return aot.compile_prog(f, mod, ctx=ctx, tgt=tgt) def test_identity():