
Commit

Remove LoweredFunc.
tqchen committed Apr 4, 2020
1 parent 400a231 commit 78ae3f9
Showing 30 changed files with 23 additions and 356 deletions.
1 change: 0 additions & 1 deletion apps/lldb/tvm.py
@@ -46,7 +46,6 @@ def __lldb_init_module(debugger, _):
"tvm::IterVarAttr",
"tvm::IterVarRelation",
"tvm::Layout",
"tir::LoweredFunc",
"tvm::Map",
"tvm::Map",
"tvm::MemoryInfo",
9 changes: 0 additions & 9 deletions docs/dev/codebase_walkthrough.rst
@@ -145,15 +145,6 @@ After lowering is done, ``build()`` function generates target machine code from

Code generation is done by ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in ``src/target/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/target/codegen/codegen.cc``:

- ::
-
-     runtime::Module Build(const Array<LoweredFunc>& funcs,
-                           const std::string& target) {
-       std::string build_f_name = "codegen.build_" + target;
-       const PackedFunc* bf = runtime::Registry::Get(build_f_name);
-       runtime::Module m = (*bf)(funcs, target);
-       return m;
-     }


The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this:
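The registration body referred to above is not shown in this excerpt. As a hedged sketch of the same PackedFunc-registry mechanism from the Python side — only ``tvm._ffi.register_func`` and ``tvm.get_global_func`` are assumed, and the target name and stub below are hypothetical, not part of this commit:

    import tvm

    # Hypothetical target "mytarget"; a real backend (e.g. CUDA) registers its
    # C++ build function under a "codegen.build_<target>" name like this.
    @tvm._ffi.register_func("codegen.build_mytarget")
    def _build_mytarget(mod, target):
        # A real backend would lower `mod` here and return a runtime module.
        raise NotImplementedError("illustrative stub")

    # Build() on the C++ side performs the equivalent of this lookup and call:
    bf = tvm.get_global_func("codegen.build_mytarget")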
3 changes: 1 addition & 2 deletions include/tvm/driver/driver_api.h
@@ -34,7 +34,6 @@
#include <tvm/support/with.h>
#include <tvm/ir/module.h>
#include <tvm/te/schedule_pass.h>
- #include <tvm/tir/lowered_func.h>

#include <string>
#include <vector>
@@ -44,7 +43,7 @@

namespace tvm {
/*!
- * \brief Build a LoweredFunc given a schedule, args and binds
+ * \brief Build an IRModule given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
12 changes: 0 additions & 12 deletions include/tvm/target/codegen.h
@@ -27,7 +27,6 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
- #include <tvm/tir/lowered_func.h>
#include <tvm/target/target.h>

#include <string>
@@ -41,17 +40,6 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

- /*!
- * \brief Temporary backward compatible function to convert a list
- * of LoweredFunc to a IRModule of PrimfFuncs
- * \param funcs The input lowered function.
- * \return The IRModule.
- *
- * \note This function is only used for code refactor and will be
- * removed once the refactor completes.
- */
- IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);
-
/*!
* \brief Build a module from array of lowered function.
* \param mod The Module to be built
1 change: 0 additions & 1 deletion include/tvm/tir/ir_pass.h
@@ -31,7 +31,6 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/function.h>
- #include <tvm/tir/lowered_func.h>

#include <unordered_map>
#include <unordered_set>
149 changes: 0 additions & 149 deletions include/tvm/tir/lowered_func.h

This file was deleted.

9 changes: 3 additions & 6 deletions python/tvm/driver/build_module.py
@@ -371,9 +371,7 @@ def build(inputs,
elif isinstance(inputs, tvm.IRModule):
input_mod = inputs
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError("inputs must be Schedule, LoweredFunc, list of "
"LoweredFunc, or dict of target to list of "
"LoweredFunc.")
raise ValueError("inputs must be Schedule, IRModule or dict of target to IRModule")

if not isinstance(inputs, (dict, container.Map)):
target = _target.Target.current() if target is None else target
@@ -387,9 +385,8 @@
raise ValueError("The key of inputs must be str or "
"_target.Target when inputs is dict.")
if not isinstance(mod, tvm.IRModule):
raise ValueError("inputs must be Schedule, LoweredFunc, IRModule,"
"or dict of str to list of "
"LoweredFunc.")
raise ValueError("inputs must be Schedule, IRModule,"
"or dict of str to IRModule.")

if not target_host:
for tar, _ in target_input_mod.items():
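A minimal usage sketch of the narrowed ``build()`` interface, assuming the ``te``/driver APIs of this TVM revision; the compute definition is illustrative, not part of this commit:

    import tvm
    from tvm import te

    # Any small te compute works as an example.
    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    mod = tvm.lower(s, [A, B], name="add_one")  # an IRModule of PrimFuncs
    lib = tvm.build(mod, target="llvm")         # Schedule, IRModule, or dict of target -> IRModule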
12 changes: 5 additions & 7 deletions python/tvm/relay/backend/_backend.py
@@ -39,7 +39,7 @@ def lower(sch, inputs, func_name, source_func):
Returns
-------
- lowered_funcs : List[tvm.LoweredFunc]
+ mod : tvm.IRModule
The result of lowering.
"""
# pylint: disable=broad-except, import-outside-toplevel
@@ -59,15 +59,13 @@


@tvm._ffi.register_func("relay.backend.build")
- def build(funcs, target, target_host=None):
+ def build(mod, target, target_host=None):
"""Backend build function.
Parameters
----------
- funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
-     A list of lowered functions or dictionary mapping from targets to
-     lowered functions.
+ mod : tvm.IRModule or Dict[str, tvm.IRModule]
+     Input module
target : tvm.Target
The target to run the code on.
@@ -82,7 +80,7 @@
"""
if target_host == "":
target_host = None
- return tvm.driver.build(funcs, target=target, target_host=target_host)
+ return tvm.driver.build(mod, target=target, target_host=target_host)


@tvm._ffi.register_func("relay._tensor_value_repr")
6 changes: 3 additions & 3 deletions python/tvm/relay/backend/graph_runtime_codegen.py
@@ -48,7 +48,7 @@ def __init__(self, mod, target):
self._get_graph_json = self._mod["get_graph_json"]
self._list_params_name = self._mod["list_params_name"]
self._get_param_by_name = self._mod["get_param_by_name"]
- self._get_lowered_funcs = self._mod["get_lowered_funcs"]
+ self._get_irmodule = self._mod["get_irmodule"]
self._setup(mod, target)

def _setup(self, mod, target):
@@ -74,14 +74,14 @@ def codegen(self, func):
-------
graph_json : str
The graph json that can be consumed by runtime.
- lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
+ mod : IRModule or Dict[str, IRModule]
The lowered functions.
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self._codegen(func)
graph_json = self._get_graph_json()
- lowered_func = self._get_lowered_funcs()
+ lowered_func = self._get_irmodule()
param_names = self._list_params_name()
params = {}
for name in param_names:
10 changes: 2 additions & 8 deletions python/tvm/target/build_config.py
@@ -20,9 +20,7 @@
import tvm.ir

from tvm.runtime import Object
- from tvm.ir import container
from tvm.tir import Stmt
- from tvm.tir.stmt import LoweredFunc
from . import _ffi_api


@@ -48,17 +46,13 @@ def decorate(self, func):
def dump(*args, **kwargs):
"""dump function"""
retv = func(*args, **kwargs)
- if not isinstance(retv, (Stmt, LoweredFunc, container.Array)):
+ if not isinstance(retv, (Stmt,)):
return retv
fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc"
with open(pname, "a") as f:
- out = retv.body if isinstance(retv, LoweredFunc) else retv
+ out = retv
f.write(str(out))
- if isinstance(retv, container.Array):
-     for x in retv:
-         out = x.body if isinstance(x, LoweredFunc) else x
-         f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1
return retv
return dump
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
@@ -29,7 +29,7 @@

from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
- from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
+ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list

from .function import PrimFunc

8 changes: 0 additions & 8 deletions python/tvm/tir/stmt.py
@@ -385,14 +385,6 @@ def __init__(self, func, value_index, dtype, bounds):
_ffi_api.Prefetch, func, value_index, dtype, bounds)


- @tvm._ffi.register_object
- class LoweredFunc(Object):
-     """Represent a LoweredFunc in TVM."""
-     MixedFunc = 0
-     HostFunc = 1
-     DeviceFunc = 2


def stmt_seq(*args):
"""Make sequence of statements
1 change: 0 additions & 1 deletion src/contrib/hybrid/codegen_hybrid.h
@@ -27,7 +27,6 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
- #include <tvm/tir/lowered_func.h>
#include <tvm/te/schedule.h>
#include <map>
#include <string>
1 change: 0 additions & 1 deletion src/driver/driver_api.cc
@@ -41,7 +41,6 @@ namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
- using tir::LoweredFunc;

bool LLVMEnabled() {
const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");