Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the window support #61

Merged
merged 1 commit into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions velox/functions/prestosql/window/Rank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ void registerRankInternal(
const std::vector<exec::WindowFunctionArg>& /*args*/,
const TypePtr& resultType,
velox::memory::MemoryPool* /*pool*/,
HashStringAllocator* /*stringAllocator*/)
-> std::unique_ptr<exec::WindowFunction> {
HashStringAllocator *
/*stringAllocator*/) -> std::unique_ptr<exec::WindowFunction> {
return std::make_unique<RankFunction<TRank, TResult>>(resultType);
});
}

void registerRank(const std::string& name) {
registerRankInternal<RankType::kRank, int64_t>(name, "bigint");
registerRankInternal<RankType::kRank, int32_t>(name, "integer");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this change used to align Rank with Spark? In the long term, maybe we could consider to implement aligned Rank under sparksql.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Will implement the Rank under sparksql function later.

}
void registerDenseRank(const std::string& name) {
registerRankInternal<RankType::kDenseRank, int64_t>(name, "bigint");
Expand Down
6 changes: 3 additions & 3 deletions velox/functions/prestosql/window/RowNumber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class RowNumberFunction : public exec::WindowFunction {
vector_size_t resultOffset,
const VectorPtr& result) override {
int numRows = peerGroupStarts->size() / sizeof(vector_size_t);
auto* rawValues = result->asFlatVector<int64_t>()->mutableRawValues();
auto* rawValues = result->asFlatVector<int32_t>()->mutableRawValues();
for (int i = 0; i < numRows; i++) {
rawValues[resultOffset + i] = rowNumber_++;
}
Expand All @@ -64,8 +64,8 @@ void registerRowNumber(const std::string& name) {
const std::vector<exec::WindowFunctionArg>& /*args*/,
const TypePtr& /*resultType*/,
velox::memory::MemoryPool* /*pool*/,
HashStringAllocator* /*stringAllocator*/)
-> std::unique_ptr<exec::WindowFunction> {
HashStringAllocator *
/*stringAllocator*/) -> std::unique_ptr<exec::WindowFunction> {
return std::make_unique<RowNumberFunction>();
});
}
Expand Down
152 changes: 152 additions & 0 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,155 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
childNode);
}

const core::WindowNode::Frame createWindowFrame(
const ::substrait::Expression_WindowFunction_Bound& lower_bound,
const ::substrait::Expression_WindowFunction_Bound& upper_bound,
const ::substrait::WindowType& type) {
core::WindowNode::Frame frame;
switch (type) {
case ::substrait::WindowType::ROWS:
frame.type = core::WindowNode::WindowType::kRows;
break;
case ::substrait::WindowType::RANGE:

frame.type = core::WindowNode::WindowType::kRange;
break;
default:
VELOX_FAIL(
"the window type only support ROWS and RANGE, and the input type is ",
type);
}

auto boundTypeConversion =
[](::substrait::Expression_WindowFunction_Bound boundType)
-> core::WindowNode::BoundType {
if (boundType.has_current_row()) {
return core::WindowNode::BoundType::kCurrentRow;
} else if (boundType.has_unbounded_following()) {
return core::WindowNode::BoundType::kUnboundedFollowing;
} else if (boundType.has_unbounded_preceding()) {
return core::WindowNode::BoundType::kUnboundedPreceding;
} else {
VELOX_FAIL("The BoundType is not supported.");
}
};
frame.startType = boundTypeConversion(lower_bound);
frame.startValue = nullptr;
frame.endType = boundTypeConversion(upper_bound);
frame.endValue = nullptr;
return frame;
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::WindowRel& windowRel) {
core::PlanNodePtr childNode;
if (windowRel.has_input()) {
childNode = toVeloxPlan(windowRel.input());
} else {
VELOX_FAIL("Child Rel is expected in WindowRel.");
}

const auto& inputType = childNode->outputType();

// Parse measures and get the window expressions.
// Each measure represents one window expression.
bool ignoreNullKeys = false;
std::vector<core::WindowNode::Function> windowNodeFunctions;
std::vector<std::string> windowColumnNames;

windowNodeFunctions.reserve(windowRel.measures().size());
for (const auto& smea : windowRel.measures()) {
const auto& windowFunction = smea.measure();
std::string funcName = subParser_->findVeloxFunction(
functionMap_, windowFunction.function_reference());
std::vector<std::shared_ptr<const core::ITypedExpr>> windowParams;
windowParams.reserve(windowFunction.arguments().size());
for (const auto& arg : windowFunction.arguments()) {
windowParams.emplace_back(
exprConverter_->toVeloxExpr(arg.value(), inputType));
}
auto windowVeloxType =
toVeloxType(subParser_->parseType(windowFunction.output_type())->type);
auto windowCall = std::make_shared<const core::CallTypedExpr>(
windowVeloxType, std::move(windowParams), funcName);
auto upperBound = windowFunction.upper_bound();
auto lowerBound = windowFunction.lower_bound();
auto type = windowFunction.window_type();

windowColumnNames.push_back(windowFunction.column_name());

windowNodeFunctions.push_back(
{std::move(windowCall),
createWindowFrame(lowerBound, upperBound, type),
ignoreNullKeys});
}

// Construct partitionKeys
std::vector<core::FieldAccessTypedExprPtr> partitionKeys;
const auto& partitions = windowRel.partition_expressions();
partitionKeys.reserve(partitions.size());
for (const auto& partition : partitions) {
auto expression = exprConverter_->toVeloxExpr(partition, inputType);
auto expr_field =
dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
VELOX_CHECK(
expr_field != nullptr,
" the partition key in Window Operator only support field")

partitionKeys.emplace_back(
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
expression));
}

std::vector<core::FieldAccessTypedExprPtr> sortingKeys;
std::vector<core::SortOrder> sortingOrders;

const auto& sorts = windowRel.sorts();
sortingKeys.reserve(sorts.size());
sortingOrders.reserve(sorts.size());

for (const auto& sort : sorts) {
switch (sort.direction()) {
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
sortingOrders.emplace_back(core::kAscNullsFirst);
break;
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
sortingOrders.emplace_back(core::kAscNullsLast);
break;
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
sortingOrders.emplace_back(core::kDescNullsFirst);
break;
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
sortingOrders.emplace_back(core::kDescNullsLast);
break;
default:
VELOX_FAIL("Sort direction is not support in WindowRel");
}

if (sort.has_expr()) {
auto expression = exprConverter_->toVeloxExpr(sort.expr(), inputType);
auto expr_field =
dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
VELOX_CHECK(
expr_field != nullptr,
" the sorting key in Window Operator only support field")

sortingKeys.emplace_back(
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
expression));
}
}

return std::make_shared<core::WindowNode>(
nextPlanNodeId(),
partitionKeys,
sortingKeys,
sortingOrders,
windowColumnNames,
windowNodeFunctions,
childNode);
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::SortRel& sortRel) {
core::PlanNodePtr childNode;
Expand Down Expand Up @@ -836,6 +985,9 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
if (sRel.has_fetch()) {
return toVeloxPlan(sRel.fetch());
}
if (sRel.has_window()) {
return toVeloxPlan(sRel.window());
}
VELOX_NYI("Substrait conversion not supported for Rel.");
}

Expand Down
3 changes: 3 additions & 0 deletions velox/substrait/SubstraitToVeloxPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class SubstraitVeloxPlanConverter {
/// Used to convert Substrait ExpandRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::ExpandRel& sExpand);

/// Used to convert Substrait SortRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::WindowRel& sWindow);

/// Used to convert Substrait SortRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::SortRel& sSort);

Expand Down
136 changes: 136 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,139 @@ bool SubstraitToVeloxPlanValidator::validate(
return true;
}

bool validateBoundType(::substrait::Expression_WindowFunction_Bound boundType) {
switch (boundType.kind_case()) {
case ::substrait::Expression_WindowFunction_Bound::kUnboundedFollowing:
case ::substrait::Expression_WindowFunction_Bound::kUnboundedPreceding:
case ::substrait::Expression_WindowFunction_Bound::kCurrentRow:
break;
default:
std::cout << "The Bound Type is not supported. "
<< "\n";
return false;
}
return true;
}

bool SubstraitToVeloxPlanValidator::validate(
const ::substrait::WindowRel& sWindow) {
if (sWindow.has_input() && !validate(sWindow.input())) {
return false;
}

// Get and validate the input types from extension.
if (!sWindow.has_advanced_extension()) {
std::cout << "Input types are expected in WindowRel." << std::endl;
return false;
}
const auto& extension = sWindow.advanced_extension();
std::vector<TypePtr> types;
if (!validateInputTypes(extension, types)) {
std::cout << "Validation failed for input types in WindowRel." << std::endl;
return false;
}

int32_t inputPlanNodeId = 0;
std::vector<std::string> names;
names.reserve(types.size());
for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx));
}
auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));

// Validate WindowFunction
std::vector<std::string> funcSpecs;
funcSpecs.reserve(sWindow.measures().size());
for (const auto& smea : sWindow.measures()) {
try {
const auto& windowFunction = smea.measure();
funcSpecs.emplace_back(
planConverter_->findFuncSpec(windowFunction.function_reference()));
toVeloxType(subParser_->parseType(windowFunction.output_type())->type);
for (const auto& arg : windowFunction.arguments()) {
auto typeCase = arg.value().rex_type_case();
switch (typeCase) {
case ::substrait::Expression::RexTypeCase::kSelection:
case ::substrait::Expression::RexTypeCase::kLiteral:
break;
default:
std::cout << "Only field is supported in window functions."
<< std::endl;
return false;
}
}
// Validate BoundType and Frame Type
switch (windowFunction.window_type()) {
case ::substrait::WindowType::ROWS:
case ::substrait::WindowType::RANGE:
break;
default:
VELOX_FAIL(
"the window type only support ROWS and RANGE, and the input type is ",
windowFunction.window_type());
}

validateBoundType(windowFunction.upper_bound());
validateBoundType(windowFunction.lower_bound());

} catch (const VeloxException& err) {
std::cout << "Validation failed for window function due to: "
<< err.message() << std::endl;
return false;
}
}

// Validate groupby expression
const auto& groupByExprs = sWindow.partition_expressions();
std::vector<std::shared_ptr<const core::ITypedExpr>> expressions;
expressions.reserve(groupByExprs.size());
try {
for (const auto& expr : groupByExprs) {
expressions.emplace_back(exprConverter_->toVeloxExpr(expr, rowType));
}
// Try to compile the expressions. If there is any unregistred funciton or
// mismatched type, exception will be thrown.
exec::ExprSet exprSet(std::move(expressions), execCtx_);
} catch (const VeloxException& err) {
std::cout << "Validation failed for expression in ProjectRel due to:"
<< err.message() << std::endl;
return false;
}

// Validate Sort expression
const auto& sorts = sWindow.sorts();
for (const auto& sort : sorts) {
switch (sort.direction()) {
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
break;
default:
return false;
}

if (sort.has_expr()) {
try {
auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType);
auto expr_field =
dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
VELOX_CHECK(
expr_field != nullptr,
" the sorting key in Sort Operator only support field")

exec::ExprSet exprSet({std::move(expression)}, execCtx_);
} catch (const VeloxException& err) {
std::cout << "Validation failed for expression in SortRel due to:"
<< err.message() << std::endl;
return false;
}
}
}

return true;
}

bool SubstraitToVeloxPlanValidator::validate(
const ::substrait::SortRel& sSort) {
if (sSort.has_input() && !validate(sSort.input())) {
Expand Down Expand Up @@ -582,6 +715,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& sRel) {
if (sRel.has_fetch()) {
return validate(sRel.fetch());
}
if (sRel.has_window()) {
return validate(sRel.window());
}
return false;
}

Expand Down
3 changes: 3 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class SubstraitToVeloxPlanValidator {
/// Used to validate whether the computing of this Sort is supported.
bool validate(const ::substrait::SortRel& sSort);

/// Used to validate whether the computing of this Window is supported.
bool validate(const ::substrait::WindowRel& sWindow);

/// Used to validate whether the computing of this Aggregation is supported.
bool validate(const ::substrait::AggregateRel& sAgg);

Expand Down
Loading