Skip to content

Commit

Permalink
Change target string to Target object in the TE compiler and interpre…
Browse files Browse the repository at this point in the history
…ter (apache#8835)

* # This is a combination of 2 commits.
# This is the 1st commit message:

Initial changes

# This is the commit message #2:

Ftarget string -> Target object works!

* Fix remaining target strings

* fix bad rebase

* Fix typo

* 1 more bad rebase fix

* Lint

* typo

* Forgot to commit this

* Add TargetStrHash and Map<Target... to std::unordered_map<Target... conversion fn

* Passing most tests, yay

* remove some comments

* lint

* target-str-to-target-object

* Respond to change requests

Co-authored-by: Jared Roesch <[email protected]>
  • Loading branch information
electriclilies and jroesch authored Aug 31, 2021
1 parent 400baf2 commit 7b91e62
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 40 deletions.
2 changes: 2 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/target/target_kind.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

Expand Down Expand Up @@ -203,5 +204,6 @@ void CheckAndUpdateHostConsistency(Map<Integer, Target>* target, Target* host);
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* target, Target* host);

} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
9 changes: 4 additions & 5 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -665,11 +665,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
ret.lowered_funcs = lowered_module.per_target_module;
ret.external_mods = lowered_module.external_mods;

auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Update(mod_run);
if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_]->Update(mod_run);
} else {
ret.lowered_funcs.Set(target_host_str, mod_run);
ret.lowered_funcs.Set(target_host_, mod_run);
}

std::vector<String> input_var_names(input_vars_.size());
Expand Down Expand Up @@ -774,7 +773,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
return (*it).second.first;
}

Map<String, IRModule> get_irmodule() { return this->output_.lowered_funcs; }
Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }

std::shared_ptr<AOTExecutorCodegen> codegen_;
LoweredOutput output_;
Expand Down
17 changes: 9 additions & 8 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ struct ExecutorCodegen {
return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
}

Map<String, IRModule> GetIRModule() {
return CallFunc<Map<String, IRModule>>("get_irmodule", nullptr);
Map<Target, IRModule> GetIRModule() {
return CallFunc<Map<Target, IRModule>>("get_irmodule", nullptr);
}

runtime::Metadata GetMetadata() { return CallFunc<runtime::Metadata>("get_metadata"); }
Expand Down Expand Up @@ -491,8 +491,9 @@ class RelayBuildModule : public runtime::ModuleNode {
auto lowered_funcs = executor_codegen_->GetIRModule();

// No need to build for external functions.
if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) {
lowered_funcs.Set("ext_dev", IRModule());
Target ext_dev("ext_dev");
if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
lowered_funcs.Set(ext_dev, IRModule());
}

// Generate a placeholder function that attaches linked params as its arguments.
Expand All @@ -510,11 +511,11 @@ class RelayBuildModule : public runtime::ModuleNode {
DictAttrs attrs{dict};
auto prim = tir::PrimFunc(Array<tir::Var>(), tir::SeqStmt(Array<tir::Stmt>()), VoidType(),
Map<tir::Var, tir::Buffer>(), attrs);
if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) {
lowered_funcs.Set(target_host->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (lowered_funcs.find(target_host) == lowered_funcs.end()) {
lowered_funcs.Set(target_host, IRModule(Map<GlobalVar, BaseFunc>({})));
}
lowered_funcs[target_host->str()]->Add(
GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim);
lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param),
prim);
}

// When there is no lowered_funcs due to reasons such as optimization.
Expand Down
30 changes: 18 additions & 12 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ namespace {
struct PairHash {
template <typename T1, typename T2>
std::size_t operator()(const std::pair<T1, T2>& k) const {
return std::hash<T1>()(k.first) ^ std::hash<T2>()(k.second);
return dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second));
}
template <typename T2>
std::size_t operator()(const std::pair<Target, T2>& k) const {
return dmlc::HashCombine(ObjectHash()(k.first), std::hash<T2>()(k.second));
}
};

Expand Down Expand Up @@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
// TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
Interpreter(IRModule mod, Map<String, IRModule> per_target_module, Device device, Target target)
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
: mod_(mod),
per_target_module_(per_target_module),
device_(device),
Expand Down Expand Up @@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
*/
PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array<GlobalVar>& all_tir_fn_vars,
Target target) {
std::pair<std::string, std::string> packed_func_key(target->str(), tir_fn_var->name_hint);
std::pair<Target, std::string> packed_func_key(target, tir_fn_var->name_hint);
auto packed_itr = compiled_packed_funcs_.find(packed_func_key);
if (packed_itr != compiled_packed_funcs_.end()) {
// Already compiled.
Expand All @@ -382,8 +386,11 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

// Project out just the function(s) we need.
IRModule lowered_projected_mod;
auto mod_itr = per_target_module_.find(target->str());
ICHECK(mod_itr != per_target_module_.end())
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
per_target_module_std_map =
backend::TargetModuleMapToTargetStrModuleMap(per_target_module_);
auto mod_itr = per_target_module_std_map.find(target);
ICHECK(mod_itr != per_target_module_std_map.end())
<< "No target module for target '" << target->str() << "'";
const IRModule& target_module = (*mod_itr).second;
for (const auto& var : all_tir_fn_vars) {
Expand All @@ -407,7 +414,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PackedFunc packed_func = runtime_module.GetFunction(var->name_hint);
ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint
<< "' in compiled module for target '" << target->str() << "'";
compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func);
compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func);
}

// Return just what we need for this call.
Expand Down Expand Up @@ -874,11 +881,10 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
// Map from target key to lowered TIR functions derived from mod_.
// Note that primitives are implicitly executed on target_, while shape functions are implicitly
// executed on the default 'cpu' host. Thus this map has at most two entries.
Map<String, IRModule> per_target_module_;
Map<Target, IRModule> per_target_module_;
// Cached packed functions for the primitives and shape functions, keyed by target and
// global var name.
std::unordered_map<std::pair<std::string, std::string>, PackedFunc, PairHash>
compiled_packed_funcs_;
std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_;
// Unique device on which primitives (but not shape functions) will be executed.
// (For simplicity we only run the interpreter on a single device.)
Device device_;
Expand All @@ -895,7 +901,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
* rewritten \p mod and target-specific modules containing bindings for all TIR primitive
* functions needed by the rewritten module.
*/
std::pair<IRModule, Map<String, IRModule>> Prepare(IRModule mod, Device device, Target target) {
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq({transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
Expand Down Expand Up @@ -1014,7 +1020,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
// and can just eval it directly.
expr_to_eval = expr;
}
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_with_expr, device, target);
std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
Expand Down Expand Up @@ -1057,7 +1063,7 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target) {
std::pair<IRModule, GlobalVar> mod_and_global =
IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_and_global.first, device, target);
Interpreter intrp(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
Expand Down
25 changes: 13 additions & 12 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,32 +85,33 @@ class TECompilerImpl : public TECompilerNode {
return LowerShapeFuncInternal(key)->cached_func;
}

Map<String, IRModule> GetLoweredFunctions() {
Map<String, IRModule> lowered_functions;
Map<Target, IRModule> GetLoweredFunctions() {
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
lowered_functions;
for (const auto& it : cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target->str())) {
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (!lowered_functions.count(target)) {
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
}

lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}

for (const auto& it : shape_func_cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target->str())) {
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (!lowered_functions.count(target)) {
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
}

lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}
return lowered_functions;
return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions);
}

Array<tvm::runtime::Module> LowerExternalFunctions() {
Expand Down Expand Up @@ -884,7 +885,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) {

// Annotate the per-target functions with their target and add them to the unified module
for (const auto& kv : mod.per_target_module) {
const String target = kv.first;
const Target target = kv.first;
const IRModule target_module = kv.second;

// Right now, per-target functions are TIR functions, which don't have type definitions, so
Expand Down Expand Up @@ -926,15 +927,15 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
main_mod->AddTypeDef(kv.first, kv.second);
}

Map<String, IRModule> per_target_modules;
Map<Target, IRModule> per_target_modules;
for (const auto& kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<relay::FunctionNode>()) {
main_mod->Add(var, func);
} else if (func->IsInstance<tir::PrimFuncNode>()) {
// Extract target
Optional<String> target = func->GetAttr<String>(tvm::attr::kTarget);
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target) << "Target should be set at this point";

// Put the function in per_target_modules
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/te_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class TECompilerNode : public Object {
virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0;

/* Return all functions which have been lowered by the compiler, keyed by target. */
virtual Map<String, IRModule> GetLoweredFunctions() = 0;
virtual Map<Target, IRModule> GetLoweredFunctions() = 0;

/*!
* \brief Just in time compile to get a PackedFunc.
Expand Down Expand Up @@ -144,7 +144,7 @@ struct LoweredModule {
/*! \brief The module which contains the Relay code. */
IRModule main_module;
/*! \brief The module which contains per target code. */
Map<String, IRModule> per_target_module;
Map<Target, IRModule> per_target_module;
/*! \brief The external runtime modules which must be combined with the lowered code. */
Array<tvm::runtime::Module> external_mods;
// TODO(@electriclilies): THis might need to become a map
Expand Down
18 changes: 18 additions & 0 deletions src/relay/backend/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,24 @@ Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is
return pass_seqs;
}

std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map) {
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> std_map;
for (auto kv : input_map) {
std_map[kv.first] = kv.second;
}
return std_map;
}

Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap(
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map) {
Map<Target, IRModule> tvm_map;
for (auto kv : input_map) {
tvm_map.Set(kv.first, kv.second);
}
return tvm_map;
}

} // namespace backend
} // namespace relay
} // namespace tvm
56 changes: 55 additions & 1 deletion src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type);
*/
struct LoweredOutput {
std::string graph_json;
Map<String, IRModule> lowered_funcs;
Map<Target, IRModule> lowered_funcs;
Array<tvm::runtime::Module> external_mods;
Map<String, FunctionInfo> function_metadata;
std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>> params;
Expand Down Expand Up @@ -427,6 +427,60 @@ inline bool IsCompileEngineCacheDisabled() {
*/
Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is_vm);

/*! \brief Target hash function */
struct TargetStrHash {
/*!
* \brief Calculate the hash code of a Target based on the string value of the Target.
Note that this hash should NOT be used in new usecases, equality of targets based on their
value is not well-defined.
This will be removed when maps from Targets to IRModules are removed from the codebase.
* \param target The Target to hash
* \return String hash of the target
*/
size_t operator()(const Target& target) const {
return String::HashBytes(target->str().c_str(), target->str().size());
}
};

/*! \brief Target equality function based on the string value of Target
Note that this equality function should NOT be used in new usecases, equality of targets based on
their value is not well-defined. This will be removed when maps from Targets to IRModules are
removed from the codebase.*/
struct TargetStrEqual {
/*!
* \brief Check if the two Targets are equal
* \param target One Target
* \param other_target The other Target
* \return String equality of the targets
*/
const bool operator()(const Target& target, const Target& other_target) const {
TargetStrHash target_hash = TargetStrHash();
return target_hash(target) == target_hash(other_target);
}
};

/*!
* \brief Convert a Map<Target, IRModule> to std::unordered_map<Target, IRmodule, TargetStrHash,
* TargetStrEqual> Target equality is currently based on pointer equality, which is a problem since
* we have a lot of Map<Target, IRModule> in the codebase. This function converts the map to a
* version that is keyed based on string value of the Target instead. Note that once we remove
* Map<Target, IRModule>, this function will be removed.
* \param input_map The map to convert
* \return The converted map
*/
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map);

/*!
* \brief Convert a std::unordered_map<Target, IRmodule, TargetStrHash, TargetStrEqual> to
* Map<Target, IRModule> This function is a helper that undoes TargetModuleMapToTargetStr. Note that
* once we remove Map<Target, IRModule>, this function will be removed.
* \param input_map The map to convert
* \return The converted map
*/
Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap(
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map);

} // namespace backend
} // namespace relay
} // namespace tvm
Expand Down

0 comments on commit 7b91e62

Please sign in to comment.