diff --git a/velox/substrait/SubstraitToVeloxPlan.cpp b/velox/substrait/SubstraitToVeloxPlan.cpp index ee482cbee273..99b219f1bd2d 100644 --- a/velox/substrait/SubstraitToVeloxPlan.cpp +++ b/velox/substrait/SubstraitToVeloxPlan.cpp @@ -99,20 +99,33 @@ const std::string sNot = "not"; // Substrait types. const std::string sI32 = "i32"; const std::string sI64 = "i64"; -} // namespace -core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( - const ::substrait::JoinRel& sJoin) { - if (!sJoin.has_left()) { - VELOX_FAIL("Left Rel is expected in JoinRel."); - } - if (!sJoin.has_right()) { - VELOX_FAIL("Right Rel is expected in JoinRel."); +/// @brief Return whether a config is set as true in AdvancedExtension +/// optimization. +/// @param extension Substrait advanced extension. +/// @param config the key string of a config. +/// @return Whether the config is set as true. +bool configSetInOptimization( + const ::substrait::extensions::AdvancedExtension& extension, + const std::string& config) { + if (extension.has_optimization()) { + std::string msg = extension.optimization().value(); + std::size_t pos = msg.find(config); + if ((pos != std::string::npos) && + (msg.substr(pos + config.size(), 1) == "1")) { + return true; + } } + return false; +} - auto leftNode = toVeloxPlan(sJoin.left()); - auto rightNode = toVeloxPlan(sJoin.right()); - +/// @brief Get the input type from both sides of join. +/// @param leftNode the plan node of left side. +/// @param rightNode the plan node of right side. +/// @return the input type. +RowTypePtr getJoinInputType( + const core::PlanNodePtr& leftNode, + const core::PlanNodePtr& rightNode) { auto outputSize = leftNode->outputType()->size() + rightNode->outputType()->size(); std::vector outputNames; @@ -125,34 +138,82 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( const auto& types = node->outputType()->children(); outputTypes.insert(outputTypes.end(), types.begin(), types.end()); } - auto outputRowType = std::make_shared( + return std::make_shared( std::move(outputNames), std::move(outputTypes)); +} - // extract join keys from join expression - std::vector leftExprs, - rightExprs; - extractJoinKeys(sJoin.expression(), leftExprs, rightExprs); - VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size()); - size_t numKeys = leftExprs.size(); +/// @brief Get the direct output type of join. +/// @param leftNode the plan node of left side. +/// @param rightNode the plan node of right side. +/// @param joinType the join type. +/// @return the output type. +RowTypePtr getJoinOutputType( + const core::PlanNodePtr& leftNode, + const core::PlanNodePtr& rightNode, + const core::JoinType& joinType) { + // Decide output type. + // Output of right semi join cannot include columns from the left side. + bool outputMayIncludeLeftColumns = + !(core::isRightSemiFilterJoin(joinType) || + core::isRightSemiProjectJoin(joinType)); + + // Output of left semi and anti joins cannot include columns from the right + // side. + bool outputMayIncludeRightColumns = + !(core::isLeftSemiFilterJoin(joinType) || + core::isLeftSemiProjectJoin(joinType) || core::isAntiJoin(joinType) || + core::isNullAwareAntiJoin(joinType)); + + if (outputMayIncludeLeftColumns && outputMayIncludeRightColumns) { + return getJoinInputType(leftNode, rightNode); + } + + if (outputMayIncludeLeftColumns) { + if (core::isLeftSemiProjectJoin(joinType)) { + auto outputSize = leftNode->outputType()->size() + 1; + std::vector outputNames = leftNode->outputType()->names(); + std::vector> outputTypes = + leftNode->outputType()->children(); + outputNames.emplace_back("exists"); + outputTypes.emplace_back(BOOLEAN()); + return std::make_shared( + std::move(outputNames), std::move(outputTypes)); + } else { + return leftNode->outputType(); + } + } - std::vector> leftKeys, - rightKeys; - leftKeys.reserve(numKeys); - rightKeys.reserve(numKeys); - for (size_t i = 0; i < numKeys; ++i) { - leftKeys.emplace_back( - exprConverter_->toVeloxExpr(*leftExprs[i], outputRowType)); - rightKeys.emplace_back( - exprConverter_->toVeloxExpr(*rightExprs[i], outputRowType)); + if (outputMayIncludeRightColumns) { + if (core::isRightSemiProjectJoin(joinType)) { + auto outputSize = rightNode->outputType()->size() + 1; + std::vector outputNames = rightNode->outputType()->names(); + std::vector> outputTypes = + rightNode->outputType()->children(); + outputNames.emplace_back("exists"); + outputTypes.emplace_back(BOOLEAN()); + return std::make_shared( + std::move(outputNames), std::move(outputTypes)); + } else { + return rightNode->outputType(); + } } + VELOX_FAIL("Output should include left or right columns."); +} +} // namespace - std::shared_ptr filter; - if (sJoin.has_post_join_filter()) { - filter = - exprConverter_->toVeloxExpr(sJoin.post_join_filter(), outputRowType); +core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( + const ::substrait::JoinRel& sJoin) { + if (!sJoin.has_left()) { + VELOX_FAIL("Left Rel is expected in JoinRel."); + } + if (!sJoin.has_right()) { + VELOX_FAIL("Right Rel is expected in JoinRel."); } - // Map join type + auto leftNode = toVeloxPlan(sJoin.left()); + auto rightNode = toVeloxPlan(sJoin.right()); + + // Map join type. core::JoinType joinType; switch (sJoin.type()) { case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER: @@ -168,25 +229,30 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( joinType = core::JoinType::kRight; break; case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: - joinType = core::JoinType::kLeftSemi; + // Determine the semi join type based on extracted information. + if (sJoin.has_advanced_extension() && + configSetInOptimization( + sJoin.advanced_extension(), "isExistenceJoin=")) { + joinType = core::JoinType::kLeftSemiProject; + } else { + joinType = core::JoinType::kLeftSemi; + } break; case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI: - joinType = core::JoinType::kRightSemi; + // Determine the semi join type based on extracted information. + if (sJoin.has_advanced_extension() && + configSetInOptimization( + sJoin.advanced_extension(), "isExistenceJoin=")) { + joinType = core::JoinType::kRightSemiProject; + } else { + joinType = core::JoinType::kRightSemi; + } break; case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI: { // Determine the anti join type based on extracted information. - bool isNullAwareAntiJoin = false; if (sJoin.has_advanced_extension() && - sJoin.advanced_extension().has_optimization()) { - std::string msg = sJoin.advanced_extension().optimization().value(); - std::string nullAwareKey = "isNullAwareAntiJoin="; - std::size_t pos = msg.find(nullAwareKey); - if ((pos != std::string::npos) && - (msg.substr(pos + nullAwareKey.size(), 1) == "1")) { - isNullAwareAntiJoin = true; - } - } - if (isNullAwareAntiJoin) { + configSetInOptimization( + sJoin.advanced_extension(), "isNullAwareAntiJoin=")) { joinType = core::JoinType::kNullAwareAnti; } else { joinType = core::JoinType::kAnti; @@ -197,6 +263,31 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( VELOX_NYI("Unsupported Join type: {}", sJoin.type()); } + // extract join keys from join expression + std::vector leftExprs, + rightExprs; + extractJoinKeys(sJoin.expression(), leftExprs, rightExprs); + VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size()); + size_t numKeys = leftExprs.size(); + + std::vector> leftKeys, + rightKeys; + leftKeys.reserve(numKeys); + rightKeys.reserve(numKeys); + auto inputRowType = getJoinInputType(leftNode, rightNode); + for (size_t i = 0; i < numKeys; ++i) { + leftKeys.emplace_back( + exprConverter_->toVeloxExpr(*leftExprs[i], inputRowType)); + rightKeys.emplace_back( + exprConverter_->toVeloxExpr(*rightExprs[i], inputRowType)); + } + + std::shared_ptr filter; + if (sJoin.has_post_join_filter()) { + filter = + exprConverter_->toVeloxExpr(sJoin.post_join_filter(), inputRowType); + } + // Create join node return std::make_shared( nextPlanNodeId(), @@ -206,7 +297,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( filter, leftNode, rightNode, - outputRowType); + getJoinOutputType(leftNode, rightNode, joinType)); } core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(