-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of relay_to_tir target hook #8423
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc) | ||
list(APPEND COMPILER_SRCS ${EXAMPLE_TARGET_HOOKS_SRC}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
|
||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#include <tvm/relay/expr_functor.h> | ||
#include <tvm/relay/transform.h> | ||
#include <tvm/tir/buffer.h> | ||
#include <tvm/tir/builtin.h> | ||
#include <tvm/tir/expr.h> | ||
#include <tvm/tir/function.h> | ||
#include <tvm/tir/op.h> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
namespace contrib { | ||
namespace example_target_hooks { | ||
|
||
class ConvertAddToSubtract : public MixedModeMutator { | ||
public: | ||
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) | ||
: ir_module_(ir_module), host_target_(host_target) {} | ||
|
||
IRModule Mutate() { | ||
GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); | ||
BaseFunc main = ir_module_->Lookup(main_global_var); | ||
Function main_func = GetRef<Function>(main.as<FunctionNode>()); | ||
|
||
// Copy everything across and mutate the body | ||
Function mutated_main = | ||
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, | ||
main_func->type_params, main_func->attrs, main_func->span); | ||
|
||
ir_module_->Update(main_global_var, mutated_main); | ||
|
||
return ir_module_; | ||
} | ||
|
||
private: | ||
tir::Load LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) { | ||
return tir::Load(DataType::Float(32), buffer->data, index, tir::const_true()); | ||
} | ||
|
||
void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) { | ||
tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x"); | ||
tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y"); | ||
tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32)); | ||
|
||
tir::Var x_var("x", DataType::Handle()); | ||
tir::Var y_var("y", DataType::Handle()); | ||
tir::Var out_var("out", DataType::Handle()); | ||
|
||
Map<String, ObjectRef> dict_attrs; | ||
dict_attrs.Set("global_symbol", new_global_var->name_hint); | ||
dict_attrs.Set("tir.noalias", Bool(true)); | ||
|
||
te::Var index("index", DataType::Int(32)); | ||
tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); | ||
tir::Stmt math_body = tir::Store(out_buffer->data, indexed_sub, index, tir::const_true()); | ||
tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); | ||
|
||
Map<tir::Var, tir::Buffer> buffer_map = { | ||
{x_var, x_buffer}, | ||
{y_var, y_buffer}, | ||
{out_var, out_buffer}, | ||
}; | ||
|
||
tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It all looks pretty slick to me, you've convinced me of the 'it's just a Pass' approach. The only part not demoed here I can think of is caching. Would you be up for adding that too? Then I think we shouldn't get hung up on trying to fold any of this handling back into the te_compiler (eg by inheriting from some 'GenericLowerTE' class or something). Thanks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My first thought is the Splitting the cache from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree on reuse by composition rather than inheritance. I wasn't proposing to refactor caching to be shared, rather proposing to show it in your example of something that's so easy to do directly we shouldn't even worry about refactoring. After writing my initial comment I remembered the prim<->prim shape function handling is also a bit subtle, but that's quite possibly something this sort of extension mechanism won't need to support anyway. Or, by the time it does we will have cleaned that part up enough the way to handle it will be obvious. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason I'd suggest using a shared cache passed between the passes is that it encapsulates the logic of:
Which includes things like tracking the That's what I think justifies factoring the cache out and using it between the constituent passes rather than a local cache per There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see it's necessary. BTW LGTM for this one exactly as is since I can see the caching issue has deserved more conversation who's outcome can easily go into a follow up. Thanks for pushing on the 'just a Pass' approach, it's so much better :-) |
||
buffer_map, DictAttrs(dict_attrs)); | ||
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); | ||
ir_module_->Add(new_global_var, replacement_func); | ||
} | ||
|
||
Expr Rewrite_(const CallNode* pre, const Expr& post) override { | ||
if (const CallNode* call = post.as<CallNode>()) { | ||
auto* func = call->op.as<FunctionNode>(); | ||
if (func == nullptr) { | ||
return post; | ||
} | ||
|
||
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol); | ||
if (func_name.defined() && func_name == "replace_add_with_subtract") { | ||
// Introduce a new global var to map the function to and copy the source type | ||
// over for InferType | ||
GlobalVar new_global_var(func_name.value()); | ||
new_global_var->checked_type_ = func->checked_type(); | ||
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func)); | ||
return Call(new_global_var, call->args, call->attrs, call->type_args, call->span); | ||
} | ||
} | ||
|
||
return post; | ||
} | ||
|
||
public: | ||
IRModule ir_module_; | ||
Target host_target_; | ||
}; | ||
|
||
transform::Pass RelayToTIR() { | ||
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func = | ||
[=](IRModule ir_module, transform::PassContext pass_context) { | ||
auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c")); | ||
return relay_to_tir.Mutate(); | ||
}; | ||
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {}); | ||
} | ||
|
||
} // namespace example_target_hooks | ||
} // namespace contrib | ||
} // namespace relay | ||
|
||
TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) | ||
.set_attr<tvm::transform::Pass>("RelayToTIR", | ||
relay::contrib::example_target_hooks::RelayToTIR()); | ||
|
||
} // namespace tvm |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file target_hooks.cc | ||
* \brief Relay passes for processing Target Hooks which have been registered on functions within | ||
* the IRModule | ||
*/ | ||
|
||
#include <tvm/relay/expr_functor.h> | ||
#include <tvm/relay/transform.h> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
namespace transform { | ||
|
||
class TargetHookVisitor : public tvm::relay::MixedModeVisitor { | ||
/*! \brief Collected pass list for all nodes */ | ||
std::vector<Pass> pass_list_; | ||
/*! \brief Attribute map for all registered targets */ | ||
TargetKindAttrMap<Pass> target_attr_map_; | ||
|
||
public: | ||
TargetHookVisitor() : target_attr_map_(tvm::TargetKind::GetAttrMap<Pass>("RelayToTIR")) {} | ||
|
||
std::vector<Pass> Visit(const IRModule& ir_mod) { | ||
for (const auto& it : ir_mod->functions) { | ||
const BaseFunc& base_func = it.second; | ||
VisitExpr(base_func); | ||
} | ||
return pass_list_; | ||
} | ||
|
||
void VisitExpr_(const CallNode* call) override { | ||
// Descend the call tree | ||
for (auto arg : call->args) { | ||
VisitExpr(arg); | ||
} | ||
|
||
if (const FunctionNode* func = call->op.as<FunctionNode>()) { | ||
if (!func->GetAttr<String>(attr::kCompiler).defined()) { | ||
return; | ||
} | ||
String code_gen_name = func->GetAttr<String>(attr::kCompiler).value(); | ||
Optional<TargetKind> target_kind = tvm::TargetKind::Get(code_gen_name); | ||
if (!target_kind || !target_attr_map_.count(target_kind.value())) { | ||
return; | ||
} | ||
Pass custom_target_pass = target_attr_map_[target_kind.value()]; | ||
if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) { | ||
pass_list_.push_back(custom_target_pass); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
Pass RelayToTIRTargetHook() { | ||
auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) { | ||
auto target_hook_visitor = TargetHookVisitor(); | ||
std::vector<Pass> pass_list = target_hook_visitor.Visit(mod); | ||
Sequential run_hooks(pass_list); | ||
|
||
return run_hooks(mod); | ||
}; | ||
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {}); | ||
} | ||
|
||
} // namespace transform | ||
} // namespace relay | ||
} // namespace tvm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
everything in this PR looks great, except that i'm a little concerned we're linking in test-only C++ here into
libtvm.so
. Possible to do this with TVMScript in Python? i feel like we would need to e.g. add a tests/libtest/src/*.cc plus a separatecmake
build target to create a.so
for the code in there that links againstlibtvm.so
, then a pytest fixture to load it in one time at the start of testing and provide the module to tests for use. and that sounds like a lot of extra ask for this pr :/There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @jroesch @tqchen @junrushao1994
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looks like there's precedent for this: c44b7bf/cmake/modules/contrib/CODEGENC.cmake
so let's not block this PR on that. we will likely need to come to a solution for this in the future.