Skip to content

Commit

Permalink
Refactor the compile engine into a cleaner interface. (#7518)
Browse files Browse the repository at this point in the history
Duplicate the CompileEngine interface.

Refactor the graph_runtime_codegen to invoke the new LowerTE pass

More changes

Things appear to be working

Some tracing to get Relay code to flow through too.

Disable some assertions as exp.

Tweak printing for now

Fix a few bugs: (#13)

1. Don't add relay main function to list of lowered TIR functions
2. Don't skip visiting call to relay function in graph runtime codegen

Remove debug prints.

Start refactoring

Split out shared data structures

Fix implicit duplicate decl of IsDynamic

Clean up handling of name + global prim fn

Clean up the code and debug issue introduced by previous hack

Clean up the debugging

Do C++ lint clean up

Update src/relay/backend/graph_executor_codegen.cc

Co-authored-by: Chris Sullivan <[email protected]>

Clean up handling of external functions

Add more error messages

More clean up

Update src/runtime/graph_executor/graph_executor.cc

Co-authored-by: Chris Sullivan <[email protected]>

Update src/runtime/graph_executor/graph_executor.cc

Co-authored-by: Chris Sullivan <[email protected]>

Update src/relay/backend/te_compiler.h

Co-authored-by: Haichen Shen <[email protected]>

Update src/relay/backend/te_compiler.h

Co-authored-by: Haichen Shen <[email protected]>

Fix

CR

More CR

Format

Fix lowering path for C++

Fix tests

Remove uncessary change

Clean up a few more things

CI fix

Fix the default context

Fix

Fix broken test cases

Update

Fix

WIP

Clean up storage data structures

WIP

WIP

Fix build errors

Remove TVMLower

Fix lint

Lint again

fix black

Move UpdateMainWorkspaceSize into te_compiler.cc

Fix link errors

Formatting

Change UpdateMainWorkspaceSize to return Map<String, FunctionInfo>

Workaround for GCC 5 error caused by enums in maps (GCC 5 is on i386 CI)

Testing how functions should be named

Lint

Change how function metadata is updated

Attempt to update aot_executor_codegen to use new StaticMemoryPlan instead of storage_device_map

Pass memory plan through LowerTE into UpdateMainWorkspaceSize so that we don't need to run GraphPlanMemory an extra time

Fix return in UpdateMainWorkspaceSize

Lint

Try to fix UpdateMainWorkspaceSize

Fix construction of static memory plan

Clean up code while debugging

Adding UpdateWorkspaceSize back

Add closure + call to UpdateFunctionMetadata (WIP)

UpdateFunctionMetadata builds; weird error with device ctx map though. Not sure if it came from this change or something else

Add some debugging of UpdateMainWorkspaceSize

Starting to move UpdateFunctionMetadata call to use process_fn infra

UWhat target should be passed to UpdateFunctionMetadata?

UpdateFunctionMetadata is not workinggg

Added some comments about UpdateFunctionMetadata for Jared

Fix the creation of function metadata

Try another stab at cleaning up the information

Fix

Port StorageInfo and StaticMemoryPlan data structure (#8297)

Restoring reshape opt

Fix tests

Caught a nasty typo from Lily, Map::Set does not mutate

Format

Disable stupid Google style warning

Rebase cleanup

Formatting

Add docstring for storage info

Black

Post rebase fix

Remove prints

Disable assert that doesn't make sense for now

Fix lint

Add copying attrs from relay node to graph node; still need to figure out how to do this in the case of global vars

Work with Lily to fix graph attrs

Try to figure out where extra arguments are coming from; fix merge

passes the profiling test

Clean up

Fix profile test

Remove debugging

Add attributes for BYOC uTVM case

Format

Dumb typo

Another fix for byoc

Format

Fix last 3 failing tests

Format

Fix final two test cases

Format

Fix lint

Fix again

Fix

Fix auto scheduler code

Fix issue

Address CR comment

Format

Co-authored-by: Jared Roesch <[email protected]>
  • Loading branch information
csullivan and jroesch committed Jul 8, 2021
1 parent 8fb4cdf commit 9c66587
Show file tree
Hide file tree
Showing 29 changed files with 2,340 additions and 1,158 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
}
};

/*!
* \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
*/
struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
/*! \brief The metadata attached to the call node. */
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
10 changes: 10 additions & 0 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def auto_schedule_topi(func_name, outs):
A tuned schedule or none (if not tuned) in the final build mode;
None in the tracing mode so that the fallback topi schedule will be used.
"""

# pylint: disable=import-outside-toplevel
from tvm.auto_scheduler.measure import (
prepare_input_map,
Expand Down Expand Up @@ -376,6 +377,15 @@ def auto_schedule_topi(func_name, outs):
return schedule


@tvm._ffi.register_func("auto_scheduler.relay_integration.te_compiler_update_weights")
def te_compiler_update_weights(function_weights):
"""A callback for updating the weights of extracted tasks."""
env = TracingEnvironment.current
if env is not None:
for key in env.wkl_key_to_weight:
env.wkl_key_to_weight[key] = function_weights[key[0]]


def tensor_no_check_call(self, *indices):
"""An indexing function without any check.
This is the same as `tvm.te.Tensor::__call__` except that the safety
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def pre_tune(self, task_scheduler, task_id):

# overall info
if all(cost < 1e9 for cost in task_scheduler.best_costs):
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3)
else:
total_latency_str = "-"
print(
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def dump(self):
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.func_name)
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
Expand All @@ -444,7 +444,7 @@ def dump(self):
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.func_name)
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
Expand Down
24 changes: 23 additions & 1 deletion python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import tvm._ffi
from tvm._ffi import base as _base
from tvm.runtime import NDArray, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar
from tvm.ir import RelayExpr, GlobalVar, Node

from .base import RelayNode
from . import _ffi_api
Expand Down Expand Up @@ -538,3 +538,25 @@ def bind(expr, binds):
The expression or function after binding.
"""
return _ffi_api.Bind(expr, binds)


@tvm._ffi.register_object("relay.StorageInfo")
class StorageInfo(Node):
"""StorageInfo
The static storage information produced by memory planning.
Contains the storage ids where expressions are stored, the
type of the "virtual devices" the expressions are stored on,
and the sizes of each storage element."""

@property
def storage_ids(self):
return _ffi_api.StorageInfoStorageIds(self)

@property
def device_types(self):
return _ffi_api.StorageInfoDeviceTypes(self)

@property
def storage_sizes(self):
return _ffi_api.StorageInfoStorageSizes(self)
10 changes: 7 additions & 3 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,18 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
}

if (target->kind->device_type == kDLCPU && target_host == target) {
ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
<< "and host_target are both llvm target."
<< "\n";
// TODO(@jroesch): This check is no longer true we need to figure out if we care about this.
// We need to relax this check for just TIR functions.
// ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
// << "and host_target are both llvm target."
// << "\n";
}

return {mhost, mdevice};
}

// Can we make this take one annotated IRModule?
//
// Build for heterogeneous execution.
runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& target_host_arg) {
auto pass_ctx = transform::PassContext::Current();
Expand Down
18 changes: 9 additions & 9 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ class AOTExecutorCodegen : public ExprVisitor {
fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
}
function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node));
function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node));
}

void VisitExpr_(const CallNode* op) override {
Expand All @@ -465,20 +465,18 @@ class AOTExecutorCodegen : public ExprVisitor {
<< "(i.e functions composed of fusable operator invocations)";
}

auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;

// Handle external function
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = Target("ext_dev");
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_);
CCacheKey key = CCacheKey(func, target);
CachedFunc ext_func = compile_engine_->Lower(key, mod_name_);
ICHECK(ext_func.defined()) << "External function is not defined.";
UpdateConstants(func, &params_);

// Generate the TIR function call
CreateFuncCall(GetRef<Call>(op), ext_func->func_name);
CreateFuncCall(GetRef<Call>(op), ext_func->prim_fn_var->name_hint);
return;
}

Expand All @@ -503,8 +501,10 @@ class AOTExecutorCodegen : public ExprVisitor {
}
target = targets_[call_dev_type];
}
CCacheKey key = (*pf0)(func, target);
CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_);

CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_);

if (!lowered_funcs_.count(target->str())) {
lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
}
Expand All @@ -513,7 +513,7 @@ class AOTExecutorCodegen : public ExprVisitor {
UpdateFunctionMetadata(lowered_func, func, target);

// Generate the TIR function call
CreateFuncCall(GetRef<Call>(op), lowered_func->func_name);
CreateFuncCall(GetRef<Call>(op), lowered_func->prim_fn_var->name_hint);
}

void VisitExpr_(const VarNode* op) override {
Expand Down
Loading

0 comments on commit 9c66587

Please sign in to comment.