Skip to content

Commit

Permalink
add expand operator (#65)
Browse files Browse the repository at this point in the history
Add expand operator support in Velox
  • Loading branch information
JkSelf authored Nov 11, 2022
1 parent 658199a commit 76fd01f
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 25 deletions.
18 changes: 8 additions & 10 deletions velox/core/PlanNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,25 +154,23 @@ void AggregationNode::addDetails(std::stringstream& stream) const {

namespace {
RowTypePtr getGroupIdOutputType(
const std::map<std::string, FieldAccessTypedExprPtr>&
outputGroupingKeyNames,
const std::vector<GroupIdNode::GroupingKeyInfo>& groupingKeyInfos,
const std::vector<FieldAccessTypedExprPtr>& aggregationInputs,
const std::string& groupIdName) {
// Grouping keys come first, followed by aggregation inputs and groupId
// column.

auto numOutputs =
outputGroupingKeyNames.size() + aggregationInputs.size() + 1;
auto numOutputs = groupingKeyInfos.size() + aggregationInputs.size() + 1;

std::vector<std::string> names;
std::vector<TypePtr> types;

names.reserve(numOutputs);
types.reserve(numOutputs);

for (const auto& [name, groupingKey] : outputGroupingKeyNames) {
names.push_back(name);
types.push_back(groupingKey->type());
for (const auto& groupingKeyInfo : groupingKeyInfos) {
names.push_back(groupingKeyInfo.output);
types.push_back(groupingKeyInfo.input->type());
}

for (const auto& input : aggregationInputs) {
Expand All @@ -190,18 +188,18 @@ RowTypePtr getGroupIdOutputType(
GroupIdNode::GroupIdNode(
PlanNodeId id,
std::vector<std::vector<FieldAccessTypedExprPtr>> groupingSets,
std::map<std::string, FieldAccessTypedExprPtr> outputGroupingKeyNames,
std::vector<GroupIdNode::GroupingKeyInfo> groupingKeyInfos,
std::vector<FieldAccessTypedExprPtr> aggregationInputs,
std::string groupIdName,
PlanNodePtr source)
: PlanNode(std::move(id)),
sources_{source},
outputType_(getGroupIdOutputType(
outputGroupingKeyNames,
groupingKeyInfos,
aggregationInputs,
groupIdName)),
groupingSets_(std::move(groupingSets)),
outputGroupingKeyNames_(std::move(outputGroupingKeyNames)),
groupingKeyInfos_(std::move(groupingKeyInfos)),
aggregationInputs_(std::move(aggregationInputs)),
groupIdName_(std::move(groupIdName)) {
VELOX_CHECK_GE(
Expand Down
19 changes: 13 additions & 6 deletions velox/core/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,10 +608,18 @@ inline std::string mapAggregationStepToName(const AggregationNode::Step& step) {
/// The rest of the grouping key columns are filled in with nulls.
class GroupIdNode : public PlanNode {
public:
/// Describes one grouping key column of the GroupIdNode output: the name the
/// column gets in the output and the input field it is read from. The output
/// column takes its type from 'input'.
struct GroupingKeyInfo {
// Name of the grouping key column in the output row type.
std::string output;
// Input field that supplies the values (and the type) of the column.
FieldAccessTypedExprPtr input;
};

/// @param id Plan node ID.
/// @param groupingSets A list of grouping key sets. Grouping keys within the
/// set must be unique, but grouping keys across sets may repeat.
/// @param outputGroupingKeyNames Output names for the grouping keys.
/// @param groupingKeyInfos The names and order of the grouping keys in the
/// output.
/// @param aggregationInputs Columns that contain inputs to the aggregate
/// functions.
/// @param groupIdName Name of the column that will contain the grouping set
Expand All @@ -620,7 +628,7 @@ class GroupIdNode : public PlanNode {
GroupIdNode(
PlanNodeId id,
std::vector<std::vector<FieldAccessTypedExprPtr>> groupingSets,
std::map<std::string, FieldAccessTypedExprPtr> outputGroupingKeyNames,
std::vector<GroupingKeyInfo> groupingKeyInfos,
std::vector<FieldAccessTypedExprPtr> aggregationInputs,
std::string groupIdName,
PlanNodePtr source);
Expand All @@ -638,9 +646,8 @@ class GroupIdNode : public PlanNode {
return groupingSets_;
}

const std::map<std::string, FieldAccessTypedExprPtr>& outputGroupingKeyNames()
const {
return outputGroupingKeyNames_;
/// Returns the grouping key descriptors (output name and input field) in the
/// order the keys appear in the output.
const std::vector<GroupingKeyInfo>& groupingKeyInfos() const {
return groupingKeyInfos_;
}

const std::vector<FieldAccessTypedExprPtr>& aggregationInputs() const {
Expand All @@ -665,7 +672,7 @@ class GroupIdNode : public PlanNode {
const std::vector<PlanNodePtr> sources_;
const RowTypePtr outputType_;
const std::vector<std::vector<FieldAccessTypedExprPtr>> groupingSets_;
const std::map<std::string, FieldAccessTypedExprPtr> outputGroupingKeyNames_;
const std::vector<GroupingKeyInfo> groupingKeyInfos_;
const std::vector<FieldAccessTypedExprPtr> aggregationInputs_;
const std::string groupIdName_;
};
Expand Down
10 changes: 6 additions & 4 deletions velox/exec/GroupId.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ GroupId::GroupId(

std::unordered_map<std::string, column_index_t>
inputToOutputGroupingKeyMapping;
for (const auto& [output, input] : groupIdNode->outputGroupingKeyNames()) {
inputToOutputGroupingKeyMapping[input->name()] =
outputType_->getChildIdx(output);
for (const auto& groupingKeyInfo : groupIdNode->groupingKeyInfos()) {
inputToOutputGroupingKeyMapping[groupingKeyInfo.input->name()] =
outputType_->getChildIdx(groupingKeyInfo.output);
}

auto numGroupingSets = groupIdNode->groupingSets().size();
Expand Down Expand Up @@ -88,11 +88,13 @@ RowVectorPtr GroupId::getOutput() {
auto numGroupingKeys = mapping.size();

// Fill in grouping keys.
auto gid = 0;
for (auto i = 0; i < numGroupingKeys; ++i) {
if (mapping[i] == kMissingGroupingKey) {
// Add null column.
outputColumns[i] = BaseVector::createNullConstant(
outputType_->childAt(i), numInput, pool());
gid = 1 << (numGroupingKeys - i - 1) | gid;
} else {
outputColumns[i] = input_->childAt(mapping[i]);
}
Expand All @@ -105,7 +107,7 @@ RowVectorPtr GroupId::getOutput() {

// Add groupId column.
outputColumns[outputType_->size() - 1] =
BaseVector::createConstant((int64_t)groupingSetIndex_, numInput, pool());
BaseVector::createConstant((int64_t)gid, numInput, pool());

++groupingSetIndex_;
if (groupingSetIndex_ == groupingKeyMappings_.size()) {
Expand Down
60 changes: 60 additions & 0 deletions velox/exec/tests/AggregationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1291,5 +1291,65 @@ TEST_F(AggregationTest, outputBatchSizeCheckWithSpill) {
}
}

// Checks that a GroupIdNode lists its grouping keys in the output in the order
// they are given in the grouping sets, and that aggregating over the keys in
// either order produces the same results once projected to a common layout.
TEST_F(AggregationTest, groupingSetsOutput) {
  const vector_size_t numRows = 1'000;
  auto data = makeRowVector(
      {"k1", "k2", "a", "b"},
      {
          makeFlatVector<int64_t>(numRows, [](auto row) { return row % 11; }),
          makeFlatVector<int64_t>(numRows, [](auto row) { return row % 17; }),
          makeFlatVector<int64_t>(numRows, [](auto row) { return row; }),
          makeFlatVector<StringView>(
              numRows,
              [](auto row) { return StringView(std::string(row % 12, 'x')); }),
      });

  createDuckDbTable({data});

  // Plan whose grouping keys are listed as (k2, k1).
  core::PlanNodePtr reversedGroupIdNode;
  auto reversedPlan =
      PlanBuilder()
          .values({data})
          .groupId({{"k2", "k1"}, {}}, {"a", "b"})
          .capturePlanNode(reversedGroupIdNode)
          .singleAggregation(
              {"k2", "k1", "group_id"},
              {"count(1) as count_1", "sum(a) as sum_a", "max(b) as max_b"})
          .project({"k1", "k2", "count_1", "sum_a", "max_b"})
          .planNode();

  // Plan whose grouping keys are listed as (k1, k2).
  core::PlanNodePtr forwardGroupIdNode;
  auto forwardPlan =
      PlanBuilder()
          .values({data})
          .groupId({{"k1", "k2"}, {}}, {"a", "b"})
          .capturePlanNode(forwardGroupIdNode)
          .singleAggregation(
              {"k1", "k2", "group_id"},
              {"count(1) as count_1", "sum(a) as sum_a", "max(b) as max_b"})
          .project({"k1", "k2", "count_1", "sum_a", "max_b"})
          .planNode();

  // The GroupId output must list the grouping keys in the requested order.
  auto reversedExpectedType =
      ROW({"k2", "k1", "a", "b", "group_id"},
          {BIGINT(), BIGINT(), BIGINT(), VARCHAR(), BIGINT()});
  auto forwardExpectedType =
      ROW({"k1", "k2", "a", "b", "group_id"},
          {BIGINT(), BIGINT(), BIGINT(), VARCHAR(), BIGINT()});
  ASSERT_EQ(*reversedGroupIdNode->outputType(), *reversedExpectedType);
  ASSERT_EQ(*forwardGroupIdNode->outputType(), *forwardExpectedType);

  // Runs 'plan' to completion without adding any splits.
  auto runPlan = [&](const core::PlanNodePtr& plan) {
    CursorParameters params;
    params.planNode = plan;
    return readCursor(params, [](Task*) {});
  };

  auto forwardResult = runPlan(forwardPlan);
  auto reversedResult = runPlan(reversedPlan);

  assertEqualResults(forwardResult.second, reversedResult.second);
}

} // namespace
} // namespace facebook::velox::exec::test
14 changes: 11 additions & 3 deletions velox/exec/tests/utils/PlanBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,17 +634,25 @@ PlanBuilder& PlanBuilder::groupId(
groupingSetExprs.push_back(fields(groupingSet));
}

std::map<std::string, core::FieldAccessTypedExprPtr> outputGroupingKeyNames;
std::vector<core::GroupIdNode::GroupingKeyInfo> groupingKeyInfos;
std::set<std::string> names;
auto index = 0;
for (const auto& groupingSet : groupingSetExprs) {
for (const auto& groupingKey : groupingSet) {
outputGroupingKeyNames[groupingKey->name()] = groupingKey;
if (names.find(groupingKey->name()) == names.end()) {
core::GroupIdNode::GroupingKeyInfo keyInfos;
keyInfos.output = groupingKey->name();
keyInfos.input = groupingKey;
groupingKeyInfos.push_back(keyInfos);
}
names.insert(groupingKey->name());
}
}

planNode_ = std::make_shared<core::GroupIdNode>(
nextPlanNodeId(),
groupingSetExprs,
std::move(outputGroupingKeyNames),
std::move(groupingKeyInfos),
fields(aggregationInputs),
std::move(groupIdName),
planNode_);
Expand Down
9 changes: 8 additions & 1 deletion velox/exec/tests/utils/PlanBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ class PlanBuilder {
const std::vector<TypePtr>& resultTypes = {});

/// Add a GroupIdNode using the specified grouping sets, aggregation inputs
/// and a groupId column name.
/// and a groupId column name. The grouping keys appear in the output of the
/// GroupIdNode in the order they first appear in 'groupingSets'.
PlanBuilder& groupId(
const std::vector<std::vector<std::string>>& groupingSets,
const std::vector<std::string>& aggregationInputs,
Expand Down Expand Up @@ -663,6 +664,12 @@ class PlanBuilder {
return *this;
}

/// Stores the current (latest) plan node in 'planNode' and continues the
/// builder chain, so that a caller can inspect an intermediate node of the
/// plan. Checks that a plan node has already been added.
///
/// @param planNode Output parameter that receives the current plan node.
PlanBuilder& capturePlanNode(core::PlanNodePtr& planNode) {
VELOX_CHECK_NOT_NULL(planNode_);
planNode = planNode_;
return *this;
}

/// Return the latest plan node, e.g. the root node of the plan tree.
const core::PlanNodePtr& planNode() const {
return planNode_;
Expand Down
77 changes: 76 additions & 1 deletion velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,14 +398,86 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
projectNames.emplace_back(subParser_->makeNodeName(planNodeId_, colIdx));
colIdx += 1;
}

return std::make_shared<core::ProjectNode>(
nextPlanNodeId(),
std::move(projectNames),
std::move(expressions),
childNode);
}

/// Converts a Substrait ExpandRel into a Velox GroupIdNode.
///
/// Every grouping-set expression and every aggregate expression must convert
/// to a plain field access; anything else fails a VELOX_CHECK. Each distinct
/// grouping key is reported once, under its input field name, in the order it
/// first appears across the grouping sets.
///
/// @param expandRel Substrait Expand relation. Must have an input Rel.
/// @return A GroupIdNode whose source is the converted input of 'expandRel'.
core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
    const ::substrait::ExpandRel& expandRel) {
  if (!expandRel.has_input()) {
    VELOX_FAIL("Child Rel is expected in ExpandRel.");
  }
  const core::PlanNodePtr childNode = toVeloxPlan(expandRel.input());
  const auto& inputType = childNode->outputType();

  std::vector<std::vector<core::FieldAccessTypedExprPtr>> groupingSetExprs;
  groupingSetExprs.reserve(expandRel.groupings_size());

  // Distinct grouping keys in first-appearance order; 'seenNames' filters out
  // keys that repeat across (or within) grouping sets.
  std::vector<core::GroupIdNode::GroupingKeyInfo> groupingKeyInfos;
  std::set<std::string> seenNames;

  for (const auto& grouping : expandRel.groupings()) {
    std::vector<core::FieldAccessTypedExprPtr> groupingExprs;
    groupingExprs.reserve(grouping.groupsets_expressions_size());

    for (const auto& groupingExpr : grouping.groupsets_expressions()) {
      auto expression =
          exprConverter_->toVeloxExpr(groupingExpr.selection(), inputType);
      // A single checked downcast both validates the expression kind and
      // yields the typed pointer.
      auto field =
          std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
              expression);
      VELOX_CHECK(
          field != nullptr,
          " the group set key in Expand Operator only support field");

      // insert() reports whether the name is new, avoiding a separate find().
      if (seenNames.insert(field->name()).second) {
        core::GroupIdNode::GroupingKeyInfo keyInfo;
        keyInfo.output = field->name();
        keyInfo.input = field;
        groupingKeyInfos.push_back(std::move(keyInfo));
      }
      groupingExprs.emplace_back(std::move(field));
    }
    groupingSetExprs.emplace_back(std::move(groupingExprs));
  }

  std::vector<core::FieldAccessTypedExprPtr> aggExprs;
  aggExprs.reserve(expandRel.aggregate_expressions_size());

  for (const auto& aggExpr : expandRel.aggregate_expressions()) {
    auto expression = exprConverter_->toVeloxExpr(aggExpr, inputType);
    auto field =
        std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
            expression);
    VELOX_CHECK(
        field != nullptr,
        " the agg key in Expand Operator only support field");
    aggExprs.emplace_back(std::move(field));
  }

  // 'expandRel' is const, so its group name can only be copied; wrapping it
  // in std::move would silently copy anyway.
  return std::make_shared<core::GroupIdNode>(
      nextPlanNodeId(),
      std::move(groupingSetExprs),
      std::move(groupingKeyInfos),
      std::move(aggExprs),
      expandRel.group_name(),
      childNode);
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::SortRel& sortRel) {
core::PlanNodePtr childNode;
Expand Down Expand Up @@ -726,6 +798,9 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
if (sRel.has_sort()) {
return toVeloxPlan(sRel.sort());
}
if (sRel.has_expand()) {
return toVeloxPlan(sRel.expand());
}
VELOX_NYI("Substrait conversion not supported for Rel.");
}

Expand Down
2 changes: 2 additions & 0 deletions velox/substrait/SubstraitToVeloxPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class SubstraitVeloxPlanConverter {
memory::MemoryPool* pool,
bool validationMode = false)
: pool_(pool), validationMode_(validationMode) {}
/// Used to convert Substrait ExpandRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::ExpandRel& sExpand);

/// Used to convert Substrait SortRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::SortRel& sSort);
Expand Down
Loading

0 comments on commit 76fd01f

Please sign in to comment.