Skip to content

Commit

Permalink
add expand operator (#65)
Browse files Browse the repository at this point in the history
Add expand operator support in Velox
  • Loading branch information
JkSelf authored and zhejiangxiaomai committed Feb 22, 2023
1 parent 9ee8152 commit 4115118
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 2 deletions.
4 changes: 3 additions & 1 deletion velox/exec/GroupId.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,13 @@ RowVectorPtr GroupId::getOutput() {
auto numGroupingKeys = mapping.size();

// Fill in grouping keys.
auto gid = 0;
for (auto i = 0; i < numGroupingKeys; ++i) {
if (mapping[i] == kMissingGroupingKey) {
// Add null column.
outputColumns[i] = BaseVector::createNullConstant(
outputType_->childAt(i), numInput, pool());
gid = 1 << (numGroupingKeys - i - 1) | gid;
} else {
outputColumns[i] = input_->childAt(mapping[i]);
}
Expand All @@ -106,7 +108,7 @@ RowVectorPtr GroupId::getOutput() {
// Add groupId column.
outputColumns[outputType_->size() - 1] =
std::make_shared<ConstantVector<int64_t>>(
pool(), numInput, false, BIGINT(), groupingSetIndex_);
pool(), numInput, false, BIGINT(), gid);

++groupingSetIndex_;
if (groupingSetIndex_ == groupingKeyMappings_.size()) {
Expand Down
60 changes: 60 additions & 0 deletions velox/exec/tests/AggregationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1402,5 +1402,65 @@ TEST_F(AggregationTest, preGroupedAggregationWithSpilling) {
OperatorTestBase::deleteTaskAndCheckSpillDirectory(task);
}

// Checks that listing the grouping keys of a grouping set in a different
// order ({k1, k2} vs. {k2, k1}) changes only the column order of the GroupId
// node's output type, not the rows produced by the aggregation on top of it.
TEST_F(AggregationTest, groupingSetsOutput) {
  vector_size_t numRows = 1'000;
  auto input = makeRowVector(
      {"k1", "k2", "a", "b"},
      {
          makeFlatVector<int64_t>(numRows, [](auto row) { return row % 11; }),
          makeFlatVector<int64_t>(numRows, [](auto row) { return row % 17; }),
          makeFlatVector<int64_t>(numRows, [](auto row) { return row; }),
          makeFlatVector<StringView>(
              numRows,
              [](auto row) { return StringView(std::string(row % 12, 'x')); }),
      });

  createDuckDbTable({input});

  // Builds values -> groupId -> singleAggregation -> project for the given
  // key order and stores the GroupId plan node in 'groupIdNode'. The second
  // grouping set is always the empty (global) set.
  auto makePlan = [&](const std::vector<std::string>& keys,
                      core::PlanNodePtr& groupIdNode) -> core::PlanNodePtr {
    const std::vector<std::vector<std::string>> groupingSets{keys, {}};

    // The aggregation groups by the keys in the same order plus the group id.
    std::vector<std::string> aggregationKeys = keys;
    aggregationKeys.emplace_back("group_id");

    return PlanBuilder()
        .values({input})
        .groupId(groupingSets, {"a", "b"})
        .capturePlanNode(groupIdNode)
        .singleAggregation(
            aggregationKeys,
            {"count(1) as count_1", "sum(a) as sum_a", "max(b) as max_b"})
        .project({"k1", "k2", "count_1", "sum_a", "max_b"})
        .planNode();
  };

  core::PlanNodePtr swappedGroupIdNode;
  core::PlanNodePtr naturalGroupIdNode;
  auto swappedPlan = makePlan({"k2", "k1"}, swappedGroupIdNode);
  auto naturalPlan = makePlan({"k1", "k2"}, naturalGroupIdNode);

  // The GroupId output lists the keys in the order they were requested,
  // followed by the aggregation inputs and the group id column.
  auto swappedExpectedType =
      ROW({"k2", "k1", "a", "b", "group_id"},
          {BIGINT(), BIGINT(), BIGINT(), VARCHAR(), BIGINT()});
  auto naturalExpectedType =
      ROW({"k1", "k2", "a", "b", "group_id"},
          {BIGINT(), BIGINT(), BIGINT(), VARCHAR(), BIGINT()});
  ASSERT_EQ(*swappedGroupIdNode->outputType(), *swappedExpectedType);
  ASSERT_EQ(*naturalGroupIdNode->outputType(), *naturalExpectedType);

  CursorParameters naturalParams;
  naturalParams.planNode = naturalPlan;
  auto naturalResult = readCursor(naturalParams, [](Task*) {});

  CursorParameters swappedParams;
  swappedParams.planNode = swappedPlan;
  auto swappedResult = readCursor(swappedParams, [](Task*) {});

  // Both plans project the same columns, so their rows must match.
  assertEqualResults(naturalResult.second, swappedResult.second);
}

} // namespace
} // namespace facebook::velox::exec::test
75 changes: 75 additions & 0 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,78 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
}
}

/// Converts a Substrait ExpandRel into a Velox GroupIdNode.
///
/// Every grouping-set expression and every aggregate input expression must
/// convert to a field access on the child's output type; anything else fails a
/// VELOX_CHECK. Grouping keys are de-duplicated by name across all grouping
/// sets, keeping the order of first occurrence, and each key's output name is
/// the name of the field it reads.
///
/// @param expandRel The Substrait relation to convert. It must carry an input
/// relation, otherwise the conversion fails.
/// @return A GroupIdNode placed on top of the converted input relation.
core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
    const ::substrait::ExpandRel& expandRel) {
  core::PlanNodePtr childNode;
  if (expandRel.has_input()) {
    childNode = toVeloxPlan(expandRel.input());
  } else {
    VELOX_FAIL("Child Rel is expected in ExpandRel.");
  }

  const auto& inputType = childNode->outputType();

  // Convert each grouping set into a list of field accesses.
  std::vector<std::vector<core::FieldAccessTypedExprPtr>> groupingSetExprs;
  groupingSetExprs.reserve(expandRel.groupings_size());

  for (const auto& grouping : expandRel.groupings()) {
    std::vector<core::FieldAccessTypedExprPtr> groupingExprs;
    groupingExprs.reserve(grouping.groupsets_expressions_size());

    for (const auto& groupingExpr : grouping.groupsets_expressions()) {
      // A single checked cast yields the typed pointer and tells us whether
      // the expression is a field access.
      auto field = std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
          exprConverter_->toVeloxExpr(groupingExpr.selection(), inputType));
      VELOX_CHECK(
          field != nullptr,
          " the group set key in Expand Operator only support field");
      groupingExprs.emplace_back(std::move(field));
    }
    groupingSetExprs.emplace_back(std::move(groupingExprs));
  }

  // Collect the distinct grouping keys. set::insert reports whether the name
  // is new, so one lookup per key is enough.
  std::vector<core::GroupIdNode::GroupingKeyInfo> groupingKeyInfos;
  std::set<std::string> seenNames;
  for (const auto& groupingSet : groupingSetExprs) {
    for (const auto& groupingKey : groupingSet) {
      if (seenNames.insert(groupingKey->name()).second) {
        core::GroupIdNode::GroupingKeyInfo keyInfo;
        keyInfo.output = groupingKey->name();
        keyInfo.input = groupingKey;
        groupingKeyInfos.push_back(std::move(keyInfo));
      }
    }
  }

  // Convert the aggregation inputs; these must be field accesses as well.
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> aggExprs;
  aggExprs.reserve(expandRel.aggregate_expressions_size());

  for (const auto& aggExpr : expandRel.aggregate_expressions()) {
    auto field = std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
        exprConverter_->toVeloxExpr(aggExpr, inputType));
    VELOX_CHECK(
        field != nullptr,
        " the agg key in Expand Operator only support field");
    aggExprs.emplace_back(std::move(field));
  }

  // 'expandRel' is const, so group_name() cannot be moved from; it is passed
  // as is. The locally built vectors are no longer needed and are moved.
  return std::make_shared<core::GroupIdNode>(
      nextPlanNodeId(),
      std::move(groupingSetExprs),
      std::move(groupingKeyInfos),
      std::move(aggExprs),
      expandRel.group_name(),
      childNode);
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::SortRel& sortRel) {
auto childNode = convertSingleInput<::substrait::SortRel>(sortRel);
Expand Down Expand Up @@ -860,6 +932,9 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
if (rel.has_sort()) {
return toVeloxPlan(rel.sort());
}
if (sRel.has_expand()) {
return toVeloxPlan(sRel.expand());
}
VELOX_NYI("Substrait conversion not supported for Rel.");
}

Expand Down
2 changes: 2 additions & 0 deletions velox/substrait/SubstraitToVeloxPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class SubstraitVeloxPlanConverter {
memory::MemoryPool* pool,
bool validationMode = false)
: pool_(pool), validationMode_(validationMode) {}
/// Used to convert Substrait ExpandRel into Velox PlanNode. The result is a
/// GroupIdNode on top of the converted input; conversion fails when the
/// ExpandRel has no input or when a grouping or aggregate expression is not a
/// field access.
core::PlanNodePtr toVeloxPlan(const ::substrait::ExpandRel& sExpand);

/// Used to convert Substrait JoinRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& sJoin);
Expand Down
63 changes: 63 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,66 @@ bool SubstraitToVeloxPlanValidator::validateInputTypes(
return true;
}

/// Checks whether the given ExpandRel is supported.
///
/// The checks run in this order and the first failure returns false:
///  1. the input relation, when present, must validate;
///  2. the advanced extension must be present and yield the input types;
///  3. every aggregate expression must convert and compile against a row
///     type built from those input types;
///  4. every grouping expression must be a field selection.
/// A reason is printed to stdout for every failure detected by steps 2-4.
bool SubstraitToVeloxPlanValidator::validate(
    const ::substrait::ExpandRel& sExpand) {
  if (sExpand.has_input()) {
    if (!validate(sExpand.input())) {
      return false;
    }
  }

  // The input column types are read from the advanced extension.
  if (!sExpand.has_advanced_extension()) {
    std::cout << "Input types are expected in ExpandRel." << std::endl;
    return false;
  }
  std::vector<TypePtr> inputTypes;
  if (!validateInputTypes(sExpand.advanced_extension(), inputTypes)) {
    std::cout << "Validation failed for input types in ExpandRel." << std::endl;
    return false;
  }

  // Build the input row type, naming each column after plan node id 0 and
  // its position.
  int32_t inputPlanNodeId = 0;
  std::vector<std::string> columnNames;
  columnNames.reserve(inputTypes.size());
  for (auto colIdx = 0; colIdx < inputTypes.size(); ++colIdx) {
    columnNames.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx));
  }
  auto inputRowType = std::make_shared<RowType>(
      std::move(columnNames), std::move(inputTypes));

  // Validate the expand agg expressions.
  try {
    std::vector<std::shared_ptr<const core::ITypedExpr>> typedExprs;
    typedExprs.reserve(sExpand.aggregate_expressions().size());
    for (const auto& aggExpr : sExpand.aggregate_expressions()) {
      typedExprs.emplace_back(
          exprConverter_->toVeloxExpr(aggExpr, inputRowType));
    }
    // Try to compile the expressions. If there is any unregistered function
    // or mismatched type, an exception will be thrown.
    exec::ExprSet exprSet(std::move(typedExprs), execCtx_);
  } catch (const VeloxException& err) {
    std::cout << "Validation failed for agg expression in ExpandRel due to:"
              << err.message() << std::endl;
    return false;
  }

  // Validate groupings: only field selections are accepted.
  for (const auto& grouping : sExpand.groupings()) {
    for (const auto& groupingExpr : grouping.groupsets_expressions()) {
      if (groupingExpr.rex_type_case() !=
          ::substrait::Expression::RexTypeCase::kSelection) {
        std::cout << "Only field is supported in groupings." << std::endl;
        return false;
      }
    }
  }

  return true;
}

bool SubstraitToVeloxPlanValidator::validate(
const ::substrait::SortRel& sSort) {
if (sSort.has_input() && !validate(sSort.input())) {
Expand Down Expand Up @@ -370,6 +430,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& sRel) {
if (sRel.has_sort()) {
return validate(sRel.sort());
}
if (sRel.has_expand()) {
return validate(sRel.expand());
}
return false;
}

Expand Down
3 changes: 3 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class SubstraitToVeloxPlanValidator {
core::ExecCtx* execCtx)
: pool_(pool), execCtx_(execCtx) {}

/// Used to validate whether the computing of this Expand is supported.
bool validate(const ::substrait::ExpandRel& sExpand);

/// Used to validate whether the computing of this Sort is supported.
bool validate(const ::substrait::SortRel& sSort);

Expand Down
20 changes: 19 additions & 1 deletion velox/substrait/proto/substrait/algebra.proto
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,24 @@ message ExchangeRel {
}
}

// A relation carrying grouping sets and aggregation input expressions. The
// Velox converter turns it into a GroupIdNode.
message ExpandRel {
RelCommon common = 1;
// The input relation. The Velox converter fails when it is absent.
Rel input = 2;

// The input expressions of the aggregations. The Velox converter accepts only
// field references here.
repeated Expression aggregate_expressions = 3;

// The list of grouping sets that the aggregation measures should be
// calculated for.
repeated GroupSets groupings = 4;

// One grouping set: the expressions that form its grouping keys. The Velox
// validator accepts only field selections here.
message GroupSets {
repeated Expression groupSets_expressions = 1;
}

// Passed to the GroupIdNode by the Velox converter; presumably the name of the
// generated group id output column (confirm against GroupIdNode).
string group_name = 5;

// The Velox validator reads the input column types from this extension and
// rejects the relation when it is missing.
substrait.extensions.AdvancedExtension advanced_extension = 10;
}

// A relation with output field names.
//
// This is for use at the root of a `Rel` tree.
Expand All @@ -367,10 +385,10 @@ message Rel {
ExtensionMultiRel extension_multi = 10;
ExtensionLeafRel extension_leaf = 11;
CrossRel cross = 12;

//Physical relations
HashJoinRel hash_join = 13;
MergeJoinRel merge_join = 14;
ExpandRel expand = 15;
}
}

Expand Down

0 comments on commit 4115118

Please sign in to comment.