Skip to content

Commit

Permalink
Add companionStep
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Apr 23, 2024
1 parent f8ff4b0 commit 569ac66
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 64 deletions.
3 changes: 2 additions & 1 deletion velox/docs/develop/aggregate-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ A simple aggregation function is implemented as a class as the following.
FunctionState& state,
const std::vector<TypePtr>& rawInputTypes,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) {
state.resultType = resultType;
}

Expand Down
4 changes: 3 additions & 1 deletion velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ class Aggregate {
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {}
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep = std::nullopt) {
}

// Initializes null flags and accumulators for newly encountered groups. This
// function should be called only once for each group.
Expand Down
23 changes: 14 additions & 9 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,17 @@ void AggregateCompanionFunctionBase::extractAccumulators(
}

void AggregateCompanionAdapter::PartialFunction::initialize(
core::AggregationNode::Step /*step*/,
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
const facebook::velox::TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> /*companionStep*/) {
fn_->initialize(
core::AggregationNode::Step::kPartial,
step,
rawInputType,
resultType,
constantInputs);
constantInputs,
core::AggregationNode::Step::kPartial);
}

void AggregateCompanionAdapter::PartialFunction::extractValues(
Expand All @@ -144,15 +146,17 @@ void AggregateCompanionAdapter::PartialFunction::extractValues(
}

void AggregateCompanionAdapter::MergeFunction::initialize(
core::AggregationNode::Step /*step*/,
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
const facebook::velox::TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> /*companionStep*/) {
fn_->initialize(
core::AggregationNode::Step::kIntermediate,
step,
rawInputType,
resultType,
constantInputs);
constantInputs,
core::AggregationNode::Step::kIntermediate);
}

void AggregateCompanionAdapter::MergeFunction::addRawInput(
Expand Down Expand Up @@ -270,7 +274,8 @@ void AggregateCompanionAdapter::ExtractFunction::apply(
core::AggregationNode::Step::kFinal,
rawInputTypes,
outputType,
constantInputs);
constantInputs,
core::AggregationNode::Step::kIntermediate);
fn_->initializeNewGroups(groups, allSelectedRange);
fn_->enableValidateIntermediateInputs();
fn_->addIntermediateResults(groups, rows, args, false);
Expand Down
6 changes: 4 additions & 2 deletions velox/exec/AggregateCompanionAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ struct AggregateCompanionAdapter {
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) override;
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) override;

void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
override;
Expand All @@ -120,7 +121,8 @@ struct AggregateCompanionAdapter {
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) override;
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) override;

void addRawInput(
char** groups,
Expand Down
11 changes: 9 additions & 2 deletions velox/exec/SimpleAggregateAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,16 @@ class SimpleAggregateAdapter : public Aggregate {
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputTypes,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) override {
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) override {
if constexpr (support_initialize_function_state_) {
FUNC::initialize(state_, step, rawInputTypes, resultType, constantInputs);
FUNC::initialize(
state_,
step,
rawInputTypes,
resultType,
constantInputs,
companionStep);
}
}

Expand Down
74 changes: 25 additions & 49 deletions velox/exec/tests/SimpleAggregateAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ class FunctionStateTestAggregate {
std::vector<TypePtr> rawInputTypes;
TypePtr resultType;
std::vector<VectorPtr> constantInputs;
core::AggregationNode::Step companionStep;
};

static void checkConstantInputs(
Expand All @@ -517,19 +518,29 @@ class FunctionStateTestAggregate {
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputTypes,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) {
auto expectedRawInputTypes = {BIGINT(), BIGINT()};
auto expectedIntermediateType = ROW({BIGINT(), DOUBLE()});

if constexpr (testCompanion) {
if (step == core::AggregationNode::Step::kPartial) {
VELOX_CHECK(companionStep.has_value());
if (companionStep.value() == core::AggregationNode::Step::kPartial) {
VELOX_CHECK(std::equal(
rawInputTypes.begin(),
rawInputTypes.end(),
expectedRawInputTypes.begin(),
expectedRawInputTypes.end()));
checkConstantInputs(constantInputs);
} else if (step == core::AggregationNode::Step::kIntermediate) {
if (step == core::AggregationNode::Step::kPartial ||
step == core::AggregationNode::Step::kSingle) {
// Only check constant inputs in partial and single step.
checkConstantInputs(constantInputs);
} else {
VELOX_CHECK_EQ(constantInputs.size(), 1);
VELOX_CHECK_NULL(constantInputs[0]);
}
} else if (
companionStep.value() == core::AggregationNode::Step::kIntermediate) {
VELOX_CHECK_EQ(rawInputTypes.size(), 1);
VELOX_CHECK(rawInputTypes[0]->equivalent(*expectedIntermediateType));

Expand All @@ -546,6 +557,7 @@ class FunctionStateTestAggregate {
expectedRawInputTypes.end()));
if (step == core::AggregationNode::Step::kPartial ||
step == core::AggregationNode::Step::kSingle) {
// Only check constant inputs in partial and single step.
checkConstantInputs(constantInputs);
} else {
VELOX_CHECK_EQ(constantInputs.size(), 1);
Expand Down Expand Up @@ -665,51 +677,15 @@ TEST_F(SimpleFunctionStateAggregationTest, aggregate) {
{},
{"simple_function_state_agg_main(c0, 1)"},
{expected});
}

TEST_F(SimpleFunctionStateAggregationTest, companionAggregate) {
auto inputVectors = makeRowVector({makeFlatVector<int64_t>({1, 2, 3, 4})});
std::vector<int64_t> accSum = {10};
std::vector<double> accCount = {4.0};
auto intermediateExpected = makeRowVector({
makeRowVector({
makeFlatVector<int64_t>(accSum),
makeFlatVector<double>(accCount),
}),
});
std::vector<double> finalResult = {2.5};
auto finalExpected = makeRowVector({makeFlatVector<double>(finalResult)});

AssertQueryBuilder(
PlanBuilder()
.values({inputVectors})
.singleAggregation(
{}, {"simple_function_state_agg_companion_partial(c0, 1)"})
.planNode())
.assertResults(intermediateExpected);

inputVectors = makeRowVector({
makeRowVector({
makeFlatVector<int64_t>({1, 2, 3, 4}),
makeFlatVector<double>({1.0, 1.0, 1.0, 1.0}),
}),
});

AssertQueryBuilder(
PlanBuilder()
.values({inputVectors})
.singleAggregation(
{}, {"simple_function_state_agg_companion_merge(c0)"})
.planNode())
.assertResults(intermediateExpected);

AssertQueryBuilder(
PlanBuilder()
.values({inputVectors})
.singleAggregation(
{}, {"simple_function_state_agg_companion_merge_extract(c0)"})
.planNode())
.assertResults(finalExpected);
testAggregationsWithCompanion(
{inputVectors},
[](auto& /*builder*/) {},
{},
{"simple_function_state_agg_companion(c0, 1)"},
{{BIGINT(), BIGINT()}},
{},
{expected},
{});
}

class SimpleFunctionStateWindowTest : public WindowTestBase {
Expand Down

0 comments on commit 569ac66

Please sign in to comment.