Skip to content

Commit

Permalink
Fix semi join output type and support existence join (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo authored and zhejiangxiaomai committed Jan 11, 2023
1 parent 4047414 commit 3679631
Showing 1 changed file with 137 additions and 46 deletions.
183 changes: 137 additions & 46 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,33 @@ const std::string sNot = "not";
// Substrait types.
const std::string sI32 = "i32";
const std::string sI64 = "i64";
} // namespace

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::JoinRel& sJoin) {
if (!sJoin.has_left()) {
VELOX_FAIL("Left Rel is expected in JoinRel.");
}
if (!sJoin.has_right()) {
VELOX_FAIL("Right Rel is expected in JoinRel.");
/// @brief Return whether a config is set as true in AdvancedExtension
/// optimization.
///
/// The optimization payload is searched for the first occurrence of
/// 'config'; the config counts as true only when the character immediately
/// following that occurrence is '1'.
/// @param extension Substrait advanced extension.
/// @param config the key string of a config, including its trailing
/// separator (callers pass keys such as "isExistenceJoin=").
/// @return Whether the config is set as true. False when the extension has no
/// optimization, when the key is absent, or when the key is not directly
/// followed by '1'.
bool configSetInOptimization(
    const ::substrait::extensions::AdvancedExtension& extension,
    const std::string& config) {
  if (!extension.has_optimization()) {
    return false;
  }

  // Bind by reference: the payload is only inspected, so copying it would be
  // wasted work.
  const std::string& msg = extension.optimization().value();
  const std::size_t pos = msg.find(config);
  if (pos == std::string::npos) {
    return false;
  }

  // A successful find() guarantees valuePos <= msg.size(); the two are equal
  // when the key is the very last thing in the payload, in which case there
  // is no value character to inspect.
  const std::size_t valuePos = pos + config.size();
  return valuePos < msg.size() && msg[valuePos] == '1';
}

auto leftNode = toVeloxPlan(sJoin.left());
auto rightNode = toVeloxPlan(sJoin.right());

/// @brief Get the input type from both sides of join.
/// @param leftNode the plan node of left side.
/// @param rightNode the plan node of right side.
/// @return the input type.
RowTypePtr getJoinInputType(
const core::PlanNodePtr& leftNode,
const core::PlanNodePtr& rightNode) {
auto outputSize =
leftNode->outputType()->size() + rightNode->outputType()->size();
std::vector<std::string> outputNames;
Expand All @@ -125,34 +138,82 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const auto& types = node->outputType()->children();
outputTypes.insert(outputTypes.end(), types.begin(), types.end());
}
auto outputRowType = std::make_shared<const RowType>(
return std::make_shared<const RowType>(
std::move(outputNames), std::move(outputTypes));
}

// extract join keys from join expression
std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
rightExprs;
extractJoinKeys(sJoin.expression(), leftExprs, rightExprs);
VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size());
size_t numKeys = leftExprs.size();
/// @brief Get the direct output type of join.
///
/// The result depends on which sides the join type allows in its output:
///  - both sides: the concatenated input type of left and right;
///  - left side only: the left output type, with a BOOLEAN "exists" column
///    appended for left semi project joins;
///  - right side only: the right output type, with a BOOLEAN "exists" column
///    appended for right semi project joins.
/// @param leftNode the plan node of left side.
/// @param rightNode the plan node of right side.
/// @param joinType the join type.
/// @return the output type.
RowTypePtr getJoinOutputType(
    const core::PlanNodePtr& leftNode,
    const core::PlanNodePtr& rightNode,
    const core::JoinType& joinType) {
  // Builds a row type holding the columns of 'sideType' followed by a BOOLEAN
  // column named "exists".
  const auto withExistsColumn = [](const RowTypePtr& sideType) -> RowTypePtr {
    std::vector<std::string> names = sideType->names();
    std::vector<std::shared_ptr<const Type>> types = sideType->children();
    names.emplace_back("exists");
    types.emplace_back(BOOLEAN());
    return std::make_shared<const RowType>(std::move(names), std::move(types));
  };

  // Output of right semi join cannot include columns from the left side.
  const bool outputMayIncludeLeftColumns =
      !(core::isRightSemiFilterJoin(joinType) ||
        core::isRightSemiProjectJoin(joinType));

  // Output of left semi and anti joins cannot include columns from the right
  // side.
  const bool outputMayIncludeRightColumns =
      !(core::isLeftSemiFilterJoin(joinType) ||
        core::isLeftSemiProjectJoin(joinType) || core::isAntiJoin(joinType) ||
        core::isNullAwareAntiJoin(joinType));

  if (outputMayIncludeLeftColumns && outputMayIncludeRightColumns) {
    return getJoinInputType(leftNode, rightNode);
  }

  if (outputMayIncludeLeftColumns) {
    if (core::isLeftSemiProjectJoin(joinType)) {
      return withExistsColumn(leftNode->outputType());
    }
    return leftNode->outputType();
  }

  if (outputMayIncludeRightColumns) {
    if (core::isRightSemiProjectJoin(joinType)) {
      return withExistsColumn(rightNode->outputType());
    }
    return rightNode->outputType();
  }

  // Defensive: reached only if a join type excludes both sides from its
  // output.
  VELOX_FAIL("Output should include left or right columns.");
}
} // namespace

std::shared_ptr<const core::ITypedExpr> filter;
if (sJoin.has_post_join_filter()) {
filter =
exprConverter_->toVeloxExpr(sJoin.post_join_filter(), outputRowType);
core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::JoinRel& sJoin) {
if (!sJoin.has_left()) {
VELOX_FAIL("Left Rel is expected in JoinRel.");
}
if (!sJoin.has_right()) {
VELOX_FAIL("Right Rel is expected in JoinRel.");
}

// Map join type
auto leftNode = toVeloxPlan(sJoin.left());
auto rightNode = toVeloxPlan(sJoin.right());

// Map join type.
core::JoinType joinType;
switch (sJoin.type()) {
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
Expand All @@ -168,25 +229,30 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
joinType = core::JoinType::kRight;
break;
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
joinType = core::JoinType::kLeftSemi;
// Determine the semi join type based on extracted information.
if (sJoin.has_advanced_extension() &&
configSetInOptimization(
sJoin.advanced_extension(), "isExistenceJoin=")) {
joinType = core::JoinType::kLeftSemiProject;
} else {
joinType = core::JoinType::kLeftSemi;
}
break;
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
joinType = core::JoinType::kRightSemi;
// Determine the semi join type based on extracted information.
if (sJoin.has_advanced_extension() &&
configSetInOptimization(
sJoin.advanced_extension(), "isExistenceJoin=")) {
joinType = core::JoinType::kRightSemiProject;
} else {
joinType = core::JoinType::kRightSemi;
}
break;
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI: {
// Determine the anti join type based on extracted information.
bool isNullAwareAntiJoin = false;
if (sJoin.has_advanced_extension() &&
sJoin.advanced_extension().has_optimization()) {
std::string msg = sJoin.advanced_extension().optimization().value();
std::string nullAwareKey = "isNullAwareAntiJoin=";
std::size_t pos = msg.find(nullAwareKey);
if ((pos != std::string::npos) &&
(msg.substr(pos + nullAwareKey.size(), 1) == "1")) {
isNullAwareAntiJoin = true;
}
}
if (isNullAwareAntiJoin) {
configSetInOptimization(
sJoin.advanced_extension(), "isNullAwareAntiJoin=")) {
joinType = core::JoinType::kNullAwareAnti;
} else {
joinType = core::JoinType::kAnti;
Expand All @@ -197,6 +263,31 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
VELOX_NYI("Unsupported Join type: {}", sJoin.type());
}

// extract join keys from join expression
std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
rightExprs;
extractJoinKeys(sJoin.expression(), leftExprs, rightExprs);
VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size());
size_t numKeys = leftExprs.size();

std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys,
rightKeys;
leftKeys.reserve(numKeys);
rightKeys.reserve(numKeys);
auto inputRowType = getJoinInputType(leftNode, rightNode);
for (size_t i = 0; i < numKeys; ++i) {
leftKeys.emplace_back(
exprConverter_->toVeloxExpr(*leftExprs[i], inputRowType));
rightKeys.emplace_back(
exprConverter_->toVeloxExpr(*rightExprs[i], inputRowType));
}

std::shared_ptr<const core::ITypedExpr> filter;
if (sJoin.has_post_join_filter()) {
filter =
exprConverter_->toVeloxExpr(sJoin.post_join_filter(), inputRowType);
}

// Create join node
return std::make_shared<core::HashJoinNode>(
nextPlanNodeId(),
Expand All @@ -206,7 +297,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
filter,
leftNode,
rightNode,
outputRowType);
getJoinOutputType(leftNode, rightNode, joinType));
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
Expand Down

0 comments on commit 3679631

Please sign in to comment.