Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix semi join output type and support existence join #67

Merged
merged 2 commits into from
Nov 10, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 137 additions & 46 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,33 @@ const std::string sNot = "not";
// Substrait types.
const std::string sI32 = "i32";
const std::string sI64 = "i64";
} // namespace

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::JoinRel& sJoin) {
if (!sJoin.has_left()) {
VELOX_FAIL("Left Rel is expected in JoinRel.");
}
if (!sJoin.has_right()) {
VELOX_FAIL("Right Rel is expected in JoinRel.");
/// @brief Return whether a config is set as true in AdvancedExtension
/// optimization.
/// @param extension Substrait advanced extension.
/// @param config the key string of a config.
/// @return Whether the config is set as true.
bool configSetInOptimization(
const ::substrait::extensions::AdvancedExtension& extension,
const std::string& config) {
if (extension.has_optimization()) {
std::string msg = extension.optimization().value();
std::size_t pos = msg.find(config);
if ((pos != std::string::npos) &&
(msg.substr(pos + config.size(), 1) == "1")) {
return true;
}
}
return false;
}

auto leftNode = toVeloxPlan(sJoin.left());
auto rightNode = toVeloxPlan(sJoin.right());

/// @brief Get the input type from both sides of join.
/// @param leftNode the plan node of left side.
/// @param rightNode the plan node of right side.
/// @return the input type.
RowTypePtr getJoinInputType(
const core::PlanNodePtr& leftNode,
const core::PlanNodePtr& rightNode) {
auto outputSize =
leftNode->outputType()->size() + rightNode->outputType()->size();
std::vector<std::string> outputNames;
Expand All @@ -108,34 +121,82 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const auto& types = node->outputType()->children();
outputTypes.insert(outputTypes.end(), types.begin(), types.end());
}
auto outputRowType = std::make_shared<const RowType>(
return std::make_shared<const RowType>(
std::move(outputNames), std::move(outputTypes));
}

// extract join keys from join expression
std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
rightExprs;
extractJoinKeys(sJoin.expression(), leftExprs, rightExprs);
VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size());
size_t numKeys = leftExprs.size();
/// @brief Get the direct output type of join.
/// @param leftNode the plan node of left side.
/// @param rightNode the plan node of right side.
/// @param joinType the join type.
/// @return the output type.
RowTypePtr getJoinOutputType(
const core::PlanNodePtr& leftNode,
const core::PlanNodePtr& rightNode,
const core::JoinType& joinType) {
// Decide output type.
// Output of right semi join cannot include columns from the left side.
bool outputMayIncludeLeftColumns =
!(core::isRightSemiFilterJoin(joinType) ||
core::isRightSemiProjectJoin(joinType));

// Output of left semi and anti joins cannot include columns from the right
// side.
bool outputMayIncludeRightColumns =
!(core::isLeftSemiFilterJoin(joinType) ||
core::isLeftSemiProjectJoin(joinType) || core::isAntiJoin(joinType) ||
core::isNullAwareAntiJoin(joinType));

if (outputMayIncludeLeftColumns && outputMayIncludeRightColumns) {
return getJoinInputType(leftNode, rightNode);
}

if (outputMayIncludeLeftColumns) {
if (core::isLeftSemiProjectJoin(joinType)) {
auto outputSize = leftNode->outputType()->size() + 1;
std::vector<std::string> outputNames = leftNode->outputType()->names();
std::vector<std::shared_ptr<const Type>> outputTypes =
leftNode->outputType()->children();
outputNames.emplace_back("exists");
outputTypes.emplace_back(BOOLEAN());
return std::make_shared<const RowType>(
std::move(outputNames), std::move(outputTypes));
} else {
return leftNode->outputType();
}
}

std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys,
rightKeys;
leftKeys.reserve(numKeys);
rightKeys.reserve(numKeys);
for (size_t i = 0; i < numKeys; ++i) {
leftKeys.emplace_back(
exprConverter_->toVeloxExpr(*leftExprs[i], outputRowType));
rightKeys.emplace_back(
exprConverter_->toVeloxExpr(*rightExprs[i], outputRowType));
if (outputMayIncludeRightColumns) {
if (core::isRightSemiProjectJoin(joinType)) {
auto outputSize = rightNode->outputType()->size() + 1;
std::vector<std::string> outputNames = rightNode->outputType()->names();
std::vector<std::shared_ptr<const Type>> outputTypes =
rightNode->outputType()->children();
outputNames.emplace_back("exists");
outputTypes.emplace_back(BOOLEAN());
return std::make_shared<const RowType>(
std::move(outputNames), std::move(outputTypes));
} else {
return rightNode->outputType();
}
}
VELOX_FAIL("Output should include left or right columns.");
}
} // namespace

std::shared_ptr<const core::ITypedExpr> filter;
if (sJoin.has_post_join_filter()) {
filter =
exprConverter_->toVeloxExpr(sJoin.post_join_filter(), outputRowType);
core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::JoinRel& sJoin) {
if (!sJoin.has_left()) {
VELOX_FAIL("Left Rel is expected in JoinRel.");
}
if (!sJoin.has_right()) {
VELOX_FAIL("Right Rel is expected in JoinRel.");
}

// Map join type
auto leftNode = toVeloxPlan(sJoin.left());
auto rightNode = toVeloxPlan(sJoin.right());

// Map join type.
core::JoinType joinType;
switch (sJoin.type()) {
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
Expand All @@ -151,25 +212,30 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
joinType = core::JoinType::kRight;
break;
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
joinType = core::JoinType::kLeftSemi;
// Determine the semi join type based on extracted information.
if (sJoin.has_advanced_extension() &&
configSetInOptimization(
sJoin.advanced_extension(), "isExistenceJoin=")) {
joinType = core::JoinType::kLeftSemiProject;
} else {
joinType = core::JoinType::kLeftSemi;
}
break;
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
joinType = core::JoinType::kRightSemi;
// Determine the semi join type based on extracted information.
if (sJoin.has_advanced_extension() &&
configSetInOptimization(
sJoin.advanced_extension(), "isExistenceJoin=")) {
joinType = core::JoinType::kRightSemiProject;
} else {
joinType = core::JoinType::kRightSemi;
}
break;
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI: {
// Determine the anti join type based on extracted information.
bool isNullAwareAntiJoin = false;
if (sJoin.has_advanced_extension() &&
sJoin.advanced_extension().has_optimization()) {
std::string msg = sJoin.advanced_extension().optimization().value();
std::string nullAwareKey = "isNullAwareAntiJoin=";
std::size_t pos = msg.find(nullAwareKey);
if ((pos != std::string::npos) &&
(msg.substr(pos + nullAwareKey.size(), 1) == "1")) {
isNullAwareAntiJoin = true;
}
}
if (isNullAwareAntiJoin) {
configSetInOptimization(
sJoin.advanced_extension(), "isNullAwareAntiJoin=")) {
joinType = core::JoinType::kNullAwareAnti;
} else {
joinType = core::JoinType::kAnti;
Expand All @@ -180,6 +246,31 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
VELOX_NYI("Unsupported Join type: {}", sJoin.type());
}

// extract join keys from join expression
std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
rightExprs;
extractJoinKeys(sJoin.expression(), leftExprs, rightExprs);
VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size());
size_t numKeys = leftExprs.size();

std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys,
rightKeys;
leftKeys.reserve(numKeys);
rightKeys.reserve(numKeys);
auto inputRowType = getJoinInputType(leftNode, rightNode);
for (size_t i = 0; i < numKeys; ++i) {
leftKeys.emplace_back(
exprConverter_->toVeloxExpr(*leftExprs[i], inputRowType));
rightKeys.emplace_back(
exprConverter_->toVeloxExpr(*rightExprs[i], inputRowType));
}

std::shared_ptr<const core::ITypedExpr> filter;
if (sJoin.has_post_join_filter()) {
filter =
exprConverter_->toVeloxExpr(sJoin.post_join_filter(), inputRowType);
}

// Create join node
return std::make_shared<core::HashJoinNode>(
nextPlanNodeId(),
Expand All @@ -189,7 +280,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
filter,
leftNode,
rightNode,
outputRowType);
getJoinOutputType(leftNode, rightNode, joinType));
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
Expand Down