From 0e9e3069dde8501f7deb50a538b9c1aceb0d5ba8 Mon Sep 17 00:00:00 2001 From: JiaKe Date: Wed, 14 Dec 2022 11:02:34 +0000 Subject: [PATCH] Add SMJ support (#97) --- velox/substrait/SubstraitParser.cpp | 31 ++++++--- velox/substrait/SubstraitParser.h | 11 ++++ velox/substrait/SubstraitToVeloxPlan.cpp | 64 +++++++++---------- .../SubstraitToVeloxPlanValidator.cpp | 13 ++++ 4 files changed, 75 insertions(+), 44 deletions(-) diff --git a/velox/substrait/SubstraitParser.cpp b/velox/substrait/SubstraitParser.cpp index 2fd3aafb4c6e..5119cc1298ae 100644 --- a/velox/substrait/SubstraitParser.cpp +++ b/velox/substrait/SubstraitParser.cpp @@ -144,14 +144,12 @@ std::shared_ptr SubstraitParser::parseType( return std::make_shared(type); } -std::string SubstraitParser::parseType( - const std::string& substraitType) { - auto it = typeMap_.find(substraitType); - if (it == typeMap_.end()) { - VELOX_NYI( - "Substrait parsing for type {} not supported.", substraitType); - } - return it->second; +std::string SubstraitParser::parseType(const std::string& substraitType) { + auto it = typeMap_.find(substraitType); + if (it == typeMap_.end()) { + VELOX_NYI("Substrait parsing for type {} not supported.", substraitType); + } + return it->second; }; std::vector> @@ -286,7 +284,7 @@ void SubstraitParser::getSubFunctionTypes( std::string delimiter = "_"; while ((pos = funcTypes.find(delimiter)) != std::string::npos) { auto type = funcTypes.substr(0, pos); - if (type != "opt" && type !="req") { + if (type != "opt" && type != "req") { types.emplace_back(type); } funcTypes.erase(0, pos + delimiter.length()); @@ -314,4 +312,19 @@ std::string SubstraitParser::mapToVeloxFunction( return subFunc; } +bool SubstraitParser::configSetInOptimization( + const ::substrait::extensions::AdvancedExtension& extension, + const std::string& config) const { + if (extension.has_optimization()) { + google::protobuf::StringValue msg; + extension.optimization().UnpackTo(&msg); + std::size_t pos = msg.value().find(config); + if ((pos != std::string::npos) && + (msg.value().substr(pos + config.size(), 1) == "1")) { + return true; + } + } + return false; +} + } // namespace facebook::velox::substrait diff --git a/velox/substrait/SubstraitParser.h b/velox/substrait/SubstraitParser.h index c836eb688194..682766657528 100644 --- a/velox/substrait/SubstraitParser.h +++ b/velox/substrait/SubstraitParser.h @@ -25,6 +25,8 @@ #include "velox/substrait/proto/substrait/type.pb.h" #include "velox/substrait/proto/substrait/type_expressions.pb.h" +#include + namespace facebook::velox::substrait { /// This class contains some common functions used to parse Substrait @@ -94,6 +96,15 @@ class SubstraitParser { /// Map the Substrait function keyword into Velox function keyword. std::string mapToVeloxFunction(const std::string& substraitFunction) const; + /// @brief Return whether a config is set as true in AdvancedExtension + /// optimization. + /// @param extension Substrait advanced extension. + /// @param config the key string of a config. + /// @return Whether the config is set as true. + bool configSetInOptimization( + const ::substrait::extensions::AdvancedExtension& extension, + const std::string& config) const; + private: /// A map used for mapping Substrait function keywords into Velox functions' /// keywords. Key: the Substrait function keyword, Value: the Velox function diff --git a/velox/substrait/SubstraitToVeloxPlan.cpp b/velox/substrait/SubstraitToVeloxPlan.cpp index c55ead8016b3..b3cf9d425eb5 100644 --- a/velox/substrait/SubstraitToVeloxPlan.cpp +++ b/velox/substrait/SubstraitToVeloxPlan.cpp @@ -16,8 +16,6 @@ #include "velox/substrait/SubstraitToVeloxPlan.h" -#include - #include "velox/substrait/TypeUtils.h" #include "velox/substrait/VariantToVectorConverter.h" #include "velox/type/Type.h" @@ -80,26 +78,6 @@ const std::string sNot = "not"; const std::string sI32 = "i32"; const std::string sI64 = "i64"; -/// @brief Return whether a config is set as true in AdvancedExtension -/// optimization. -/// @param extension Substrait advanced extension. -/// @param config the key string of a config. -/// @return Whether the config is set as true. -bool configSetInOptimization( - const ::substrait::extensions::AdvancedExtension& extension, - const std::string& config) { - if (extension.has_optimization()) { - google::protobuf::StringValue msg; - extension.optimization().UnpackTo(&msg); - std::size_t pos = msg.value().find(config); - if ((pos != std::string::npos) && - (msg.value().substr(pos + config.size(), 1) == "1")) { - return true; - } - } - return false; -} - /// @brief Get the input type from both sides of join. /// @param leftNode the plan node of left side. /// @param rightNode the plan node of right side. @@ -236,7 +214,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: // Determine the semi join type based on extracted information. if (sJoin.has_advanced_extension() && - configSetInOptimization( + subParser_->configSetInOptimization( sJoin.advanced_extension(), "isExistenceJoin=")) { joinType = core::JoinType::kLeftSemiProject; } else { @@ -246,7 +224,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI: // Determine the semi join type based on extracted information. if (sJoin.has_advanced_extension() && - configSetInOptimization( + subParser_->configSetInOptimization( sJoin.advanced_extension(), "isExistenceJoin=")) { joinType = core::JoinType::kRightSemiProject; } else { @@ -256,7 +234,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI: { // Determine the anti join type based on extracted information. if (sJoin.has_advanced_extension() && - configSetInOptimization( + subParser_->configSetInOptimization( sJoin.advanced_extension(), "isNullAwareAntiJoin=")) { joinType = core::JoinType::kNullAwareAnti; } else { @@ -293,16 +271,32 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( exprConverter_->toVeloxExpr(sJoin.post_join_filter(), inputRowType); } - // Create join node - return std::make_shared( - nextPlanNodeId(), - joinType, - leftKeys, - rightKeys, - filter, - leftNode, - rightNode, - getJoinOutputType(leftNode, rightNode, joinType)); + if (sJoin.has_advanced_extension() && + subParser_->configSetInOptimization( + sJoin.advanced_extension(), "isSMJ=")) { + // Create MergeJoinNode node + return std::make_shared( + nextPlanNodeId(), + joinType, + leftKeys, + rightKeys, + filter, + leftNode, + rightNode, + getJoinOutputType(leftNode, rightNode, joinType)); + + } else { + // Create HashJoinNode node + return std::make_shared( + nextPlanNodeId(), + joinType, + leftKeys, + rightKeys, + filter, + leftNode, + rightNode, + getJoinOutputType(leftNode, rightNode, joinType)); + } } core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.cpp b/velox/substrait/SubstraitToVeloxPlanValidator.cpp index 89bf331e9fcb..ec108f0050d4 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.cpp +++ b/velox/substrait/SubstraitToVeloxPlanValidator.cpp @@ -429,6 +429,19 @@ bool SubstraitToVeloxPlanValidator::validate( return false; } + if (sJoin.has_advanced_extension() && + subParser_->configSetInOptimization( + sJoin.advanced_extension(), "isSMJ=")) { + switch (sJoin.type()) { + case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: + break; + default: + std::cout << "Sort merge join only support inner and left join" + << std::endl; + return false; + } + } switch (sJoin.type()) { case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER: case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: