Implementation of relay_to_tir target hook #8423

Merged (1 commit, Sep 16, 2021)
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -407,6 +407,7 @@ include(cmake/modules/contrib/EthosU.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
+include(cmake/modules/contrib/ExampleTargetHooks.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Posit.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
19 changes: 19 additions & 0 deletions cmake/modules/contrib/ExampleTargetHooks.cmake
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc)
Contributor:

Everything in this PR looks great, except that I'm a little concerned we're linking test-only C++ into libtvm.so here. Is it possible to do this with TVMScript in Python? I feel like we would need to e.g. add a tests/libtest/src/*.cc plus a separate CMake build target to create a .so for the code in there that links against libtvm.so, then a pytest fixture to load it once at the start of testing and provide the module to tests for use. And that sounds like a lot of extra ask for this PR :/

Contributor:

It looks like there's precedent for this: c44b7bf/cmake/modules/contrib/CODEGENC.cmake

So let's not block this PR on that; we will likely need to come to a solution for this in the future.

list(APPEND COMPILER_SRCS ${EXAMPLE_TARGET_HOOKS_SRC})
7 changes: 7 additions & 0 deletions include/tvm/relay/transform.h
@@ -426,6 +426,13 @@ TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
*/
TVM_DLL Pass SimplifyExpr();

+/*!
+* \brief Run any RelayToTIR passes registered on the functions in a module.
+*
+* \return The pass.
+*/
+TVM_DLL Pass RelayToTIRTargetHook();

/*!
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
131 changes: 131 additions & 0 deletions src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
@@ -0,0 +1,131 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace relay {
namespace contrib {
namespace example_target_hooks {

class ConvertAddToSubtract : public MixedModeMutator {
public:
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
: ir_module_(ir_module), host_target_(host_target) {}

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);

ir_module_->Update(main_global_var, mutated_main);

return ir_module_;
}

private:
tir::Load LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) {
return tir::Load(DataType::Float(32), buffer->data, index, tir::const_true());
}

void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) {
tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x");
tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y");
tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32));

tir::Var x_var("x", DataType::Handle());
tir::Var y_var("y", DataType::Handle());
tir::Var out_var("out", DataType::Handle());

Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", new_global_var->name_hint);
dict_attrs.Set("tir.noalias", Bool(true));

te::Var index("index", DataType::Int(32));
tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
tir::Stmt math_body = tir::Store(out_buffer->data, indexed_sub, index, tir::const_true());
tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);

Map<tir::Var, tir::Buffer> buffer_map = {
{x_var, x_buffer},
{y_var, y_buffer},
{out_var, out_buffer},
};

tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
Contributor:

It all looks pretty slick to me; you've convinced me of the 'it's just a Pass' approach. The only part not demoed here that I can think of is caching. Would you be up for adding that too? Then I think we shouldn't get hung up on trying to fold any of this handling back into the te_compiler (e.g. by inheriting from some 'GenericLowerTE' class or something). Thanks.

Member Author:

My first thought is that the Pass infra works well as a composition approach, where we give a series of tools to the user rather than having them extend specific classes. To incorporate caching, I'd change the signature to something like tvm::transform::Pass(const CompileCache& cache) or similar, so we can pass the cache between the intermediary passes?

Splitting the cache from te_compiler.cc and factoring it in like that seems big enough for a separate PR?
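
A minimal sketch of the proposed shape (hedged: CompileCache is hypothetical and sketched further down this thread; nothing below exists in this PR):

#include <tvm/ir/transform.h>

class CompileCache;  // hypothetical shared lowering cache, sketched below

// Hypothetical: each hook factory accepts the shared cache, so every
// constituent pass consults one (Function, Target) -> lowered-result store.
tvm::transform::Pass RelayToTIR(const CompileCache& cache);

// Registration would then thread the same cache instance through each hook,
// mirroring the TVM_REGISTER_TARGET_KIND call at the end of relay_to_tir.cc:
//
//   TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
//       .set_attr<tvm::transform::Pass>("RelayToTIR", RelayToTIR(shared_cache));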

Contributor:

Agree on reuse by composition rather than inheritance. I wasn't proposing to refactor caching to be shared; rather, I was proposing to show it in your example as something that's so easy to do directly that we shouldn't even worry about refactoring.

After writing my initial comment I remembered the prim<->prim shape function handling is also a bit subtle, but that's quite possibly something this sort of extension mechanism won't need to support anyway. Or, by the time it does, we will have cleaned that part up enough that the way to handle it will be obvious.

Member Author:

The reason I'd suggest using a shared cache passed between the passes is that it encapsulates the logic of:

(Function, Target) -> PrimFunc / External Function / Empty Node

which includes things like tracking the UniqueName of a node. We could use the IRModule itself as a form of cache in relay_to_tir.cc, but it'd likely be better to give the full cache capability to the hooks so we don't have to re-implement too much of that.

That's what I think justifies factoring the cache out and using it between the constituent passes rather than a local cache per Pass. What do you think?
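
For reference, the mapping described above could be realised roughly like this (a hedged sketch only: CompileCache and its members are illustrative assumptions, and pointer-identity ordering stands in for real structural hashing):

#include <tvm/ir/expr.h>
#include <tvm/relay/function.h>
#include <tvm/target/target.h>

#include <map>
#include <utility>

// Hypothetical shared cache realising (Function, Target) -> lowered GlobalVar,
// covering the PrimFunc / external function / empty-node cases by whatever the
// GlobalVar ends up bound to in the IRModule.
class CompileCache {
 public:
  tvm::Optional<tvm::GlobalVar> Lookup(const tvm::relay::Function& func,
                                       const tvm::Target& target) const {
    auto it = entries_.find({func, target});
    if (it == entries_.end()) return tvm::Optional<tvm::GlobalVar>();
    return tvm::Optional<tvm::GlobalVar>(it->second);
  }

  void Insert(const tvm::relay::Function& func, const tvm::Target& target,
              tvm::GlobalVar gv) {
    entries_[{func, target}] = std::move(gv);
  }

 private:
  // std::map works because ObjectRef defines pointer-identity operator<;
  // a production cache would also track unique symbol names.
  std::map<std::pair<tvm::relay::Function, tvm::Target>, tvm::GlobalVar> entries_;
};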

Contributor:

I don't see that it's necessary.
Within a pass, it's a few lines to ensure two calls to the same attrs::kPrimitive Function are rewritten to call the same PrimFunc; we could probably even use the memoization built into ExprMutator (see the sketch after this comment).
Between passes there's nothing left to say -- it has all been encapsulated within the IRModule itself.
Even if it were necessary for some peculiar reason, I'd turn the conversation to figuring out how to extend IRModule or attributes or whatever, to again ensure there's no special state between passes other than what is spelled out in the IRModule.
Does that make sense?

BTW, LGTM for this one exactly as is, since the caching issue clearly deserves more conversation, whose outcome can easily go into a follow-up. Thanks for pushing on the 'just a Pass' approach, it's so much better :-)
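
For illustration, the "few lines" of per-pass memoization could look like this (a sketch under stated assumptions: HookMutator is hypothetical, the actual lowering into the IRModule is elided, and the kGlobalSymbol attribute is assumed present):

#include <tvm/ir/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/function.h>

#include <unordered_map>

// Hypothetical mutator: every call to the same primitive Function is
// rewritten to call one shared GlobalVar, so its PrimFunc is emitted once.
class HookMutator : public tvm::relay::MixedModeMutator {
 public:
  tvm::relay::Expr Rewrite_(const tvm::relay::CallNode* pre,
                            const tvm::relay::Expr& post) override {
    const auto* call = post.as<tvm::relay::CallNode>();
    const auto* func = call->op.as<tvm::relay::FunctionNode>();
    if (func == nullptr) return post;

    auto it = memo_.find(func);
    if (it == memo_.end()) {
      // First sighting: mint the replacement symbol (emitting the PrimFunc
      // into the IRModule, as ReplaceAddWithSubtractPrimFunc does, is elided).
      tvm::GlobalVar gv(func->GetAttr<tvm::String>(tvm::attr::kGlobalSymbol).value());
      it = memo_.emplace(func, gv).first;
    }
    return tvm::relay::Call(it->second, call->args, call->attrs, call->type_args);
  }

 private:
  // One GlobalVar per already-seen primitive Function.
  std::unordered_map<const tvm::relay::FunctionNode*, tvm::GlobalVar> memo_;
};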

buffer_map, DictAttrs(dict_attrs));
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
ir_module_->Add(new_global_var, replacement_func);
}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (const CallNode* call = post.as<CallNode>()) {
auto* func = call->op.as<FunctionNode>();
if (func == nullptr) {
return post;
}

auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
if (func_name.defined() && func_name == "replace_add_with_subtract") {
// Introduce a new global var to map the function to and copy the source type
// over for InferType
GlobalVar new_global_var(func_name.value());
new_global_var->checked_type_ = func->checked_type();
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);
}
}

return post;
}

public:
IRModule ir_module_;
Target host_target_;
};

transform::Pass RelayToTIR() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule ir_module, transform::PassContext pass_context) {
auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c"));
return relay_to_tir.Mutate();
};
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {});
}

} // namespace example_target_hooks
} // namespace contrib
} // namespace relay

TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
.set_attr<tvm::transform::Pass>("RelayToTIR",
relay::contrib::example_target_hooks::RelayToTIR());

} // namespace tvm
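
For context, a minimal sketch of the kind of Relay function this hook fires on (illustrative only: MakeHookedAdd is a hypothetical helper, and the attribute values mirror the checks in TargetHookVisitor and ConvertAddToSubtract):

#include <tvm/ir/op.h>
#include <tvm/ir/tensor_type.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>

// Hypothetical helper: builds x + y and annotates it so that (a) the
// TargetHookVisitor matches kCompiler against the "example_target_hook"
// target kind and schedules its registered RelayToTIR pass, and (b)
// ConvertAddToSubtract rewrites the call, matching
// kGlobalSymbol == "replace_add_with_subtract".
tvm::relay::Function MakeHookedAdd() {
  using tvm::relay::Call;
  using tvm::relay::Function;
  using tvm::relay::Var;
  auto ty = tvm::TensorType({8}, tvm::DataType::Float(32));
  Var x("x", ty);
  Var y("y", ty);
  Call add_call(tvm::Op::Get("add"), {x, y});
  Function func({x, y}, add_call, ty, {});
  func = tvm::WithAttr(func, tvm::relay::attr::kCompiler, tvm::String("example_target_hook"));
  func = tvm::WithAttr(func, tvm::attr::kGlobalSymbol, tvm::String("replace_add_with_subtract"));
  return func;
}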
35 changes: 26 additions & 9 deletions src/relay/backend/te_compiler.cc
@@ -131,6 +131,7 @@ class TECompilerImpl : public TECompilerNode {
Array<tvm::runtime::Module> ret;
std::unordered_map<std::string, std::string> cached_symbol;
std::vector<CCacheKey> cached_ext_funcs;

for (const auto& it : cache_) {
auto src_func = it.first->source_func;
ICHECK(src_func.defined());
@@ -383,10 +384,12 @@ class LowerTensorExprMutator : public ExprMutator {
* \brief Returns the primitive function associated with \p expr, or
* nullptr if none.
*/
-Function ResolveToPrimitive(Expr expr) {
+BaseFunc ResolveToPrimitive(Expr expr) {
if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
BaseFunc base_func = module_->Lookup(GetRef<GlobalVar>(gvn));
return ResolveToPrimitive(base_func);
+} else if (const tir::PrimFuncNode* prim_func = expr.as<tir::PrimFuncNode>()) {
+return GetRef<tir::PrimFunc>(prim_func);
} else if (const VarNode* vn = expr.as<VarNode>()) {
auto itr = primitive_functions_.find(GetRef<Var>(vn));
return itr == primitive_functions_.end() ? Function() : itr->second;
@@ -516,10 +519,17 @@ class LowerTensorExprMutator : public ExprMutator {
Expr VisitExpr_(const LetNode* let) override {
Var var = Downcast<Var>(Mutate(let->var));
Expr value = Mutate(let->value);
-Function prim_func = ResolveToPrimitive(value);
+BaseFunc prim_func = ResolveToPrimitive(value);

if (prim_func.defined()) {
+// Already lowered by other means, no need to mutate the Let node
+if (prim_func->IsInstance<tir::PrimFuncNode>()) {
+return GetRef<Let>(let);
+}

// Remember let var is bound (possibly indirectly) to a primitive.
-primitive_functions_.emplace(let->var, prim_func);
+Function func = Downcast<Function>(prim_func);
+primitive_functions_.emplace(let->var, func);
}
Expr body = Mutate(let->body);
if (prim_func.defined()) {
@@ -537,7 +547,7 @@
Call expr = GetRef<Call>(call);

// Look for (indirect) calls to primitives.
-Function prim_func = ResolveToPrimitive(call->op);
+BaseFunc prim_func = ResolveToPrimitive(call->op);
if (!prim_func.defined()) {
// Not a call to a primitive function.
if (const FunctionNode* fn = call->op.as<FunctionNode>()) {
@@ -546,6 +556,12 @@
return ExprMutator::VisitExpr_(call);
}

+// Already lowered by other means so we don't need to mutate
+// the call
+if (prim_func->IsInstance<tir::PrimFuncNode>()) {
+return expr;
+}

// Find the desired target device.
Target target;
if (prim_func->GetAttr<String>(attr::kCompiler).defined()) {
@@ -565,7 +581,8 @@
}

// Lower the primitive function for that target.
-std::pair<GlobalVar, Attrs> pair = LowerFunction(prim_func, target);
+Function func = Downcast<Function>(prim_func);
+std::pair<GlobalVar, Attrs> pair = LowerFunction(func, target);

// Similarly transform arguments.
Array<Expr> args;
@@ -648,8 +665,6 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets)

backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets,
Map<Expr, backend::StorageInfo> storage_info_map) {
-CHECK_EQ(mod->functions.size(), 1)
-<< "There should only be one function in the module passed to UpdateMainWorkspaceSize";
Function func = Downcast<Function>(mod->Lookup("main"));

// This is a Map<device,Map<storage_id, size>>
@@ -926,8 +941,10 @@ Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
PassContext ctx) {
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
};
-return tvm::transform::Sequential(
-{tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});
+
+return tvm::transform::Sequential({tvm::relay::transform::RelayToTIRTargetHook(),
+tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}),
+InferType()});
}
} // namespace tec
} // namespace relay
86 changes: 86 additions & 0 deletions src/relay/transforms/target_hooks.cc
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file target_hooks.cc
* \brief Relay passes for processing Target Hooks which have been registered on functions within
* the IRModule
*/

#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {
namespace transform {

class TargetHookVisitor : public tvm::relay::MixedModeVisitor {
/*! \brief Collected pass list for all nodes */
std::vector<Pass> pass_list_;
/*! \brief Attribute map for all registered targets */
TargetKindAttrMap<Pass> target_attr_map_;

public:
TargetHookVisitor() : target_attr_map_(tvm::TargetKind::GetAttrMap<Pass>("RelayToTIR")) {}

std::vector<Pass> Visit(const IRModule& ir_mod) {
for (const auto& it : ir_mod->functions) {
const BaseFunc& base_func = it.second;
VisitExpr(base_func);
}
return pass_list_;
}

void VisitExpr_(const CallNode* call) override {
// Descend the call tree
for (auto arg : call->args) {
VisitExpr(arg);
}

if (const FunctionNode* func = call->op.as<FunctionNode>()) {
if (!func->GetAttr<String>(attr::kCompiler).defined()) {
return;
}
String code_gen_name = func->GetAttr<String>(attr::kCompiler).value();
Optional<TargetKind> target_kind = tvm::TargetKind::Get(code_gen_name);
if (!target_kind || !target_attr_map_.count(target_kind.value())) {
return;
}
Pass custom_target_pass = target_attr_map_[target_kind.value()];
if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) {
pass_list_.push_back(custom_target_pass);
}
}
}
};

Pass RelayToTIRTargetHook() {
auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) {
auto target_hook_visitor = TargetHookVisitor();
std::vector<Pass> pass_list = target_hook_visitor.Visit(mod);
Sequential run_hooks(pass_list);

return run_hooks(mod);
};
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {});
}

} // namespace transform
} // namespace relay
} // namespace tvm
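
And to close the loop, a hedged sketch of running the new pass standalone (RunHooks is illustrative and is not test code from this PR):

#include <tvm/ir/module.h>
#include <tvm/relay/transform.h>

// Illustrative only: applies RelayToTIRTargetHook to a module. Functions
// called through a kCompiler-annotated FunctionNode get their target kind's
// registered RelayToTIR pass run; modules without such calls pass through
// unchanged (the collected pass list is simply empty).
tvm::IRModule RunHooks(tvm::IRModule mod) {
  tvm::transform::Pass hook = tvm::relay::transform::RelayToTIRTargetHook();
  return hook(mod);
}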