Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Runtime][WIP] Add prototype Relay AoT compiler directly into TVM #6219

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions python/tvm/relay/backend/aot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
This module defines the Relay ahead-of-time (AoT) compiler,
which translates Relay ASTs into C++ code that calls into
already-compiled operators. These end-to-end compiled
programs can in principle run without a runtime.
"""
from .aot import compile_prog
282 changes: 282 additions & 0 deletions python/tvm/relay/backend/aot/aot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Defines the entry point into the AoT compiler.
"""
import ctypes
import os
import subprocess
import tempfile
import time

import tvm
from tvm import relay, get_global_func, register_func
from tvm.relay.function import Function
from tvm.relay.expr import Expr, Let, GlobalVar
from tvm.relay.adt import Constructor
from tvm.relay.expr_functor import ExprFunctor
from tvm.relay.backend import compile_engine
from .little_cpp import (PackedCall, CPPFunction, Invoke, Decl, CPPIf,
CPPTuple, CPPMatch, CPPConstructor, CPPTupleGetItem,
CPPRefCreate, CPPRefRead, CPPRefWrite)
from . import to_source
from .convert import convert

TVM_PATH = os.environ['TVM_HOME']

def must_run_process(args):
proc = subprocess.run(args, check=True)
assert proc.returncode == 0

def compile_cpp(source, lib_name, flags=None, lib_path=None):
"""
Compiles the given source into a C++ library
and returns the full path to the compiled library.
"""
if flags is None:
flags = []

if lib_path is None:
lib_path = os.curdir

debug_source_path = os.path.join(lib_path, 'source.cc')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you put this functionality behind a debug flag?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

# Write out the file for debugging.
with open(debug_source_path, 'w') as source_file:
source_file.write(source)

# with tempfile.TmporaryDirectory() as tmpdir:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

tmpdir = tempfile.mkdtemp(prefix="relay_aot_compiler")
lib_path = os.path.join(tmpdir, lib_name)
source_path = os.path.join(tmpdir, 'source.cc')
with open(source_path, 'w') as source_file:
source_file.write(source)

must_run_process(["clang-format", "-i", debug_source_path])

system = os.uname()[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make this an argument to the compile cpp_function?
In future, it would be easier to enable cross-compilation.

include_paths = [
f"-I{TVM_PATH}/3rdparty/dmlc-core/include",
f"-I{TVM_PATH}/3rdparty/dlpack/include",
f"-I{TVM_PATH}/3rdparty/HalideIR/src",
f"-I{TVM_PATH}/include",
f"-L{TVM_PATH}/build"
]

if system == 'Darwin':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compiler config could be a class variable -- a dictionary to be precise.
Maybe look that up for the flags and refactor the rest as they are mostly same.
Maybe a comment explaining the special cased flags for "Darwin" could be useful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think this should be written in a more maintainable manner. I will see if it can be made to programmatically match up with TVM's own C++ build configuration, for example

command = [
"clang", "-std=c++14", "-shared", "-undefined", "dynamic_lookup",
"-o", lib_path,
source_path,
*include_paths,
"-ltvm"
] + flags
else:
command = [
"clang", "-std=c++14", "-shared", "-fPIC", "-o", lib_path,
source_path,
*include_paths,
"-ltvm"
] + flags

must_run_process(command)
return lib_path

def load_lib(name):
return ctypes.CDLL(name, ctypes.RTLD_GLOBAL)

def is_primitive(expr: relay.Expr):
return (isinstance(expr, relay.Function)
and expr.attrs
and expr.attrs.Primitive.value == 1)

class AoTCompiler(ExprFunctor):
"""
Takes a Relay program and converts into a Little CPP program
that can in turn be converted into C++ source code.
"""
def __init__(self, mod, tgt) -> None:
super().__init__()
self.mod = mod
self.tgt = tgt
self.engine = compile_engine.get()
self.bindings = [[]]
self.gv_map = {}

def add_binding(self, var, value):
self.bindings[-1].append((var, value))

def optimize(self, expr: Function) -> Function:
opts = tvm.transform.Sequential([
relay.transform.SimplifyInference(),
relay.transform.FuseOps(),
relay.transform.ToANormalForm()])
self.mod['main'] = expr
self.mod = opts(self.mod)
ret = self.mod['main']
return ret

def mk_primitive_op(self, func: Expr, args, output_type) -> Expr:
cc_key = compile_engine.CCacheKey(func, self.tgt)
func_hash = tvm.ir.structural_hash(func)
name = f"op_{func_hash}"
if not get_global_func(name, allow_missing=True):
jit_func = self.engine.jit(cc_key, self.tgt)
register_func(name, jit_func)
return PackedCall(name, args, [x.checked_type for x in args], output_type)

def visit_call(self, call: Expr) -> Expr:
if is_primitive(call.op):
return self.mk_primitive_op(call.op, call.args, call.checked_type)
if isinstance(call.op, Constructor):
return CPPConstructor(call.op.tag, [self.visit(arg) for arg in call.args])
assert call.attrs is None
args = [self.visit(arg) for arg in call.args]
func = self.visit(call.op)
return Invoke(func, args)

def visit_let(self, let: Expr) -> Expr:
self.bindings.append([])

while isinstance(let, Let):
cpp_value = self.visit(let.value)
self.add_binding(let.var, cpp_value)
let = let.body

bindings = self.bindings.pop()
body = self.visit(let)

return Decl(bindings, body)

def visit_var(self, var):
return var

def visit_global_var(self, gv):
if gv not in self.gv_map:
self.gv_map[gv] = "to be updated"
self.gv_map[gv] = self.visit(self.mod[gv])
return gv

def visit_function(self, func):
if is_primitive(func):
body = self.mk_primitive_op(func, func.params, func.ret_type)
return CPPFunction(func.params, body, func.checked_type.ret_type)
return CPPFunction(func.params, self.visit(func.body), func.checked_type.ret_type)

def visit_constant(self, const):
return const

def visit_if(self, i):
return CPPIf(self.visit(i.cond),
self.visit(i.true_branch),
self.visit(i.false_branch),
i.checked_type)

def visit_tuple(self, t):
return CPPTuple([self.visit(f) for f in t.fields], t.checked_type)

def visit_match(self, m):
return CPPMatch(self.visit(m.data),
[(c.lhs, self.visit(c.rhs)) for c in m.clauses],
m.checked_type)

def visit_op(self, op):
raise Exception(f'op outside of primitive: {op}')

def visit_constructor(self, ctor):
raise Exception('Constructors should be handled when visiting calls.')

def visit_tuple_getitem(self, t):
return CPPTupleGetItem(self.visit(t.tuple_value), t.index, t.checked_type)

def visit_ref_create(self, r):
return CPPRefCreate(self.visit(r.value), r.checked_type)

def visit_ref_read(self, r):
return CPPRefRead(self.visit(r.ref), r.checked_type)

def visit_ref_write(self, r):
return CPPRefWrite(self.visit(r.ref), self.visit(r.value))

_LIB_COUNTER = 1
_LIB = []

def lib_and_func_name(name):
global _LIB_COUNTER
packed_name = f'relay.aot.{name}.{_LIB_COUNTER}'
lib_name = f"librelay_aot_{_LIB_COUNTER}.so"
_LIB_COUNTER += 1
return lib_name, packed_name

def _mk_wrapper(func, ctx, constants, record_time):
Copy link
Contributor

@manupak manupak Aug 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wrapper seems to enable runtime perf. measurement; Maybe add a comment describing the functionality;
Also it would be better make record_time, default to False.
[Suggestion] make this a decorator and remove the wrapper from the implemetation and use the decorator where the performance measurement is needed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may remove the time recording feature entirely for now, as it was something we included ad hoc for a single experiment

def _wrapper(*args):
new_constants = [convert(a, ctx) for a in constants]
new_args = [convert(a, ctx) for a in args]
begin = time.perf_counter()
res = func(*new_constants, *new_args)
end = time.perf_counter()
return res if not record_time else (res, end - begin)
return _wrapper

def compile_prog(func, mod, ctx, tgt, name='default', record_time=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be a good idea to make the func the "main" inside mod ? So that we only needs to pass the mod with a "main".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was definitely a flaw in the research prototype that we never took the time to correct. I agree that using the main function would be a much more sensible convention

"""Compile a Relay function into a C++ file that
implements a program with the same semantics,
which calls into TVM only for operators.

Parameters
----------
func: Expr
A Relay function to compile
(either a literal Relay function
or a GlobalVar that is in `mod`).

mod: IRModule
Module containing any functions referenced by `func`.

ctx: Context
The TVM context.

tgt: Target
The TVM target.

name: String
The name of the target binary library.

record_time: Bool
If True, the return value of the function
will include the program's execution time.

Returns
-------
result: Function
A function that, when pass in some values,
will convert them to the right format
and call the compiled func (a PackedFunc).
"""
global _LIB
if isinstance(func, GlobalVar):
func = mod[func]
assert isinstance(func, Function)
compiler = AoTCompiler(mod, tgt)
func = compiler.optimize(func)
func = compiler.visit(func)
lib_name, packed_name = lib_and_func_name(name)
constants, source_code = to_source.to_source(func, compiler.gv_map, ctx, packed_name)
lib_name = f"librelay_aot_{_LIB_COUNTER}.so"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for user to optionally give the paths for the artifacts : .so and .cc ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be a good option, I will include that in my next revision.

library_path = compile_cpp(source_code, lib_name, flags=["-O3"])
_LIB.append(load_lib(library_path))
func = get_global_func(packed_name)
return _mk_wrapper(func, ctx, constants, record_time)
51 changes: 51 additions & 0 deletions python/tvm/relay/backend/aot/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Responsible for converting function arguments into
a form that can be passed to a `PackedFunc`.
"""
import numpy as np
import tvm
from tvm import relay

def convert(a, ctx):
"""
Converts a function input `a`
(which may take constant defined in Relay, numpy arrays,
or TVM NDArrays)
into a form that can be passed to a TVM `PackedFunc`
with the given context.
"""
# convert(convert(a, tg), tg) = convert(a, tg)
while True:
if isinstance(a, int):
a = np.array(a, dtype='int32')
elif isinstance(a, np.ndarray):
a = tvm.nd.array(a, ctx)
elif isinstance(a, tvm.runtime.NDArray):
return a
elif isinstance(a, relay.Call):
assert isinstance(a.op, relay.Constructor)
a = (a.op, *a.args)
elif isinstance(a, tuple):
assert isinstance(a[0], relay.Constructor)
a = relay.backend.interpreter.ConstructorValue(
a[0].tag, [convert(arg, ctx) for arg in a[1:]], a[0])
elif isinstance(a, relay.backend.interpreter.ConstructorValue):
return a
else:
raise Exception(a, type(a))
Loading