diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 9c1fe55749e4..deec662e74ad 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -31,6 +31,7 @@ #include #include +#include #include #include @@ -203,5 +204,6 @@ void CheckAndUpdateHostConsistency(Map* target, Target* host); * \param host The Target typed object for target host to be updated */ void CheckAndUpdateHostConsistency(Map* target, Target* host); + } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 942bc0d1d44a..2b88f0489321 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -669,11 +669,10 @@ class AOTExecutorCodegen : public ExprVisitor { ret.lowered_funcs = lowered_module.per_target_module; ret.external_mods = lowered_module.external_mods; - auto target_host_str = target_host_->str(); - if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { - ret.lowered_funcs[target_host_str]->Update(mod_run); + if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) { + ret.lowered_funcs[target_host_]->Update(mod_run); } else { - ret.lowered_funcs.Set(target_host_str, mod_run); + ret.lowered_funcs.Set(target_host_, mod_run); } std::vector input_var_names(input_vars_.size()); @@ -778,7 +777,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { return (*it).second.first; } - Map get_irmodule() { return this->output_.lowered_funcs; } + Map get_irmodule() { return this->output_.lowered_funcs; } std::shared_ptr codegen_; LoweredOutput output_; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b2b73e9bad02..69dced36295e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -92,8 +92,8 @@ struct ExecutorCodegen { return CallFunc>("get_external_modules", nullptr); } - Map GetIRModule() { - return 
CallFunc>("get_irmodule", nullptr); + Map GetIRModule() { + return CallFunc>("get_irmodule", nullptr); } runtime::Metadata GetMetadata() { return CallFunc("get_metadata"); } @@ -491,8 +491,9 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = executor_codegen_->GetIRModule(); // No need to build for external functions. - if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) { - lowered_funcs.Set("ext_dev", IRModule()); + Target ext_dev("ext_dev"); + if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) { + lowered_funcs.Set(ext_dev, IRModule()); } // Generate a placeholder function that attaches linked params as its arguments. @@ -510,11 +511,11 @@ class RelayBuildModule : public runtime::ModuleNode { DictAttrs attrs{dict}; auto prim = tir::PrimFunc(Array(), tir::SeqStmt(Array()), VoidType(), Map(), attrs); - if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) { - lowered_funcs.Set(target_host->str(), IRModule(Map({}))); + if (lowered_funcs.find(target_host) == lowered_funcs.end()) { + lowered_funcs.Set(target_host, IRModule(Map({}))); } - lowered_funcs[target_host->str()]->Add( - GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim); + lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), + prim); } // When there is no lowered_funcs due to reasons such as optimization. 
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index af2cbae1f72d..76b6f9186eb5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -53,7 +53,11 @@ namespace { struct PairHash { template std::size_t operator()(const std::pair& k) const { - return std::hash()(k.first) ^ std::hash()(k.second); + return dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } + template + std::size_t operator()(const std::pair& k) const { + return dmlc::HashCombine(ObjectHash()(k.first), std::hash()(k.second)); } }; @@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. - Interpreter(IRModule mod, Map per_target_module, Device device, Target target) + Interpreter(IRModule mod, Map per_target_module, Device device, Target target) : mod_(mod), per_target_module_(per_target_module), device_(device), @@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor, */ PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array& all_tir_fn_vars, Target target) { - std::pair packed_func_key(target->str(), tir_fn_var->name_hint); + std::pair packed_func_key(target, tir_fn_var->name_hint); auto packed_itr = compiled_packed_funcs_.find(packed_func_key); if (packed_itr != compiled_packed_funcs_.end()) { // Already compiled. @@ -382,8 +386,11 @@ class Interpreter : public ExprFunctor, // Project out just the function(s) we need. 
IRModule lowered_projected_mod; - auto mod_itr = per_target_module_.find(target->str()); - ICHECK(mod_itr != per_target_module_.end()) + std::unordered_map + per_target_module_std_map = + backend::TargetModuleMapToTargetStrModuleMap(per_target_module_); + auto mod_itr = per_target_module_std_map.find(target); + ICHECK(mod_itr != per_target_module_std_map.end()) << "No target module for target '" << target->str() << "'"; const IRModule& target_module = (*mod_itr).second; for (const auto& var : all_tir_fn_vars) { @@ -407,7 +414,7 @@ class Interpreter : public ExprFunctor, PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint << "' in compiled module for target '" << target->str() << "'"; - compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func); + compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); } // Return just what we need for this call. @@ -874,11 +881,10 @@ class Interpreter : public ExprFunctor, // Map from target key to lowered TIR functions derived from mod_. // Note that primitives are implicitly executed on target_, while shape functions are implicitly // executed on the default 'cpu' host. Thus this map has at most two entries. - Map per_target_module_; + Map per_target_module_; // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. - std::unordered_map, PackedFunc, PairHash> - compiled_packed_funcs_; + std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) Device device_; @@ -895,7 +901,7 @@ class Interpreter : public ExprFunctor, * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. 
*/ -std::pair> Prepare(IRModule mod, Device device, Target target) { +std::pair> Prepare(IRModule mod, Device device, Target target) { // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq({transform::SimplifyInference(), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' @@ -1014,7 +1020,7 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // and can just eval it directly. expr_to_eval = expr; } - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_with_expr, device, target); std::shared_ptr intrp = std::make_shared( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, @@ -1057,7 +1063,7 @@ ObjectRef Eval(Expr expr, Map type_definitions, std::unordered_set import_set, Device device, Target target) { std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_and_global.first, device, target); Interpreter intrp( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 71ac752ec680..06d862b781e1 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -85,18 +85,19 @@ class TECompilerImpl : public TECompilerNode { return LowerShapeFuncInternal(key)->cached_func; } - Map GetLoweredFunctions() { - Map lowered_functions; + Map GetLoweredFunctions() { + std::unordered_map + lowered_functions; for (const auto& it : cache_) { auto source_func = it.first; auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions[target] = IRModule(Map({})); } - 
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } for (const auto& it : shape_func_cache_) { @@ -104,13 +105,13 @@ class TECompilerImpl : public TECompilerNode { auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions[target] = IRModule(Map({})); } - lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } - return lowered_functions; + return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions); } Array LowerExternalFunctions() { @@ -884,7 +885,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { // Annotate the per-target functions with their target and add them to the unified module for (const auto& kv : mod.per_target_module) { - const String target = kv.first; + const Target target = kv.first; const IRModule target_module = kv.second; // Right now, per-target functions are TIR functions, which don't have type definitions, so @@ -926,7 +927,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->AddTypeDef(kv.first, kv.second); } - Map per_target_modules; + Map per_target_modules; for (const auto& kv : mod->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; @@ -934,7 +935,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->Add(var, func); } else if (func->IsInstance()) { // Extract target - Optional target = func->GetAttr(tvm::attr::kTarget); + Optional target = func->GetAttr(tvm::attr::kTarget); ICHECK(target) << "Target should be set at this point"; // Put the function in per_target_modules diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index e9cfb0d62e66..65ba67ac7e1b 100644 --- 
a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -97,7 +97,7 @@ class TECompilerNode : public Object { virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; /* Return all functions which have been lowered by the compiler, keyed by target. */ - virtual Map GetLoweredFunctions() = 0; + virtual Map GetLoweredFunctions() = 0; /*! * \brief Just in time compile to get a PackedFunc. @@ -144,7 +144,7 @@ struct LoweredModule { /*! \brief The module which contains the Relay code. */ IRModule main_module; /*! \brief The module which contains per target code. */ - Map per_target_module; + Map per_target_module; /*! \brief The external runtime modules which must be combined with the lowered code. */ Array external_mods; // TODO(@electriclilies): THis might need to become a map diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 4b4844599e29..ea0ab093aa1d 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -187,6 +187,24 @@ Array GetPassPrefix(const Map& targets, bool is return pass_seqs; } +std::unordered_map +TargetModuleMapToTargetStrModuleMap(Map input_map) { + std::unordered_map std_map; + for (auto kv : input_map) { + std_map[kv.first] = kv.second; + } + return std_map; +} + +Map TargetStrModuleMapToTargetModuleMap( + std::unordered_map input_map) { + Map tvm_map; + for (auto kv : input_map) { + tvm_map.Set(kv.first, kv.second); + } + return tvm_map; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a0c7a5aad26d..cf8a2dd4b8e0 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type); */ struct LoweredOutput { std::string graph_json; - Map lowered_funcs; + Map lowered_funcs; Array external_mods; Map function_metadata; std::unordered_map> params; @@ -427,6 +427,60 @@ inline bool 
IsCompileEngineCacheDisabled() {
  */
 Array GetPassPrefix(const Map& targets, bool is_vm);
 
+/*! \brief Target hash function */
+struct TargetStrHash {
+  /*!
+   * \brief Calculate the hash code of a Target based on the string value of the Target.
+   * Note that this hash should NOT be used in new usecases, equality of targets based on their
+   * value is not well-defined.
+   * This will be removed when maps from Targets to IRModules are removed from the codebase.
+   * \param target The Target to hash
+   * \return String hash of the target
+   */
+  size_t operator()(const Target& target) const {
+    return String::HashBytes(target->str().c_str(), target->str().size());
+  }
+};
+
+/*! \brief Target equality function based on the string value of Target
+ * Note that this equality function should NOT be used in new usecases, equality of targets based
+ * on their value is not well-defined. This will be removed when maps from Targets to IRModules
+ * are removed from the codebase. */
+struct TargetStrEqual {
+  /*!
+   * \brief Check if the two Targets are equal
+   * \param target One Target
+   * \param other_target The other Target
+   * \return String equality of the targets
+   */
+  bool operator()(const Target& target, const Target& other_target) const {
+    // Compare the strings themselves; equal hashes alone would treat colliding targets as equal.
+    return target->str() == other_target->str();
+  }
+};
+
+/*!
+ * \brief Convert a Map<Target, IRModule> to a std::unordered_map<Target, IRModule, TargetStrHash,
+ * TargetStrEqual>. Target equality is currently based on pointer equality, which is a problem
+ * since we have a lot of Map<Target, IRModule> in the codebase. This function converts the map to
+ * a version that is keyed based on string value of the Target instead. Note that once we remove
+ * Map<Target, IRModule>, this function will be removed.
+ * \param input_map The map to convert
+ * \return The converted map
+ */
+std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
+TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map);
+
+/*!
+ * \brief Convert a std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> to
+ * Map<Target, IRModule>. This function is a helper that undoes TargetModuleMapToTargetStrModuleMap.
Note that + * once we remove Map, this function will be removed. + * \param input_map The map to convert + * \return The converted map + */ +Map TargetStrModuleMapToTargetModuleMap( + std::unordered_map input_map); + } // namespace backend } // namespace relay } // namespace tvm