diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index c11a5dabc211..e872ae5dda94 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -152,13 +152,33 @@ A simple aggregation function is implemented as a class as the following. using IntermediateType = Array>; using OutputType = Array>; + // Define a struct for function-level states. Even if the aggregation function + // doesn't use function-level states, it is still necessary to define an empty + // FunctionState struct. + struct FunctionState { + // Optional. + TypePtr resultType; + }; + + // Optional. Defined only when the aggregation function needs to use function-level states. + // This method is called once when the aggregation function is created. + static void initialize( + core::AggregationNode::Step step, + FunctionState& state, + const std::vector& rawInputTypes, + const TypePtr& resultType, + const std::vector& constantInputs, + std::optional companionStep) { + state.resultType = resultType; + } + // Optional. Default is true. static constexpr bool default_null_behavior_ = false; // Optional. static bool toIntermediate( - exec::out_type>>& out, - exec::optional_arg_type> in); + exec::out_type>>& out, + exec::optional_arg_type> in); struct AccumulatorType { ... }; }; @@ -169,6 +189,15 @@ function's argument type(s) wrapped in a Row<> even if the function only takes one argument. This is needed for the SimpleAggregateAdapter to parse input types for arbitrary aggregation functions properly. +A FunctionState struct needs to be declared in the simple aggregation function +class. FunctionState is initialized once when the aggregation function is +created and used at every row when adding inputs to accumulators or extracting +values from accumulators. For example, if the aggregation function needs to get +the result type or the raw input type of the aggregaiton function, the author +can hold them in the FunctionState struct, and initialize them in the +initialize() method. If the aggregation function does not require the use of +FunctionState, it is necessary to declare an empty FunctionState struct. + The author can define an optional flag `default_null_behavior_` indicating whether the aggregation function has default-null behavior. This flag is true by default. Next, the class can have an optional method `toIntermediate()` @@ -257,17 +286,21 @@ For aggregaiton functions of default-null behavior, the author defines an // Optional. Default is false. static constexpr bool is_aligned_ = true; - explicit AccumulatorType(HashStringAllocator* allocator); + explicit AccumulatorType(HashStringAllocator* allocator, const FunctionState& state); - void addInput(HashStringAllocator* allocator, exec::arg_type value1, ...); + void addInput( + HashStringAllocator* allocator, + exec::arg_type value1, ..., + const FunctionState& state); void combine( HashStringAllocator* allocator, - exec::arg_type other); + exec::arg_type other, + const FunctionState& state); - bool writeIntermediateResult(exec::out_type& out); + bool writeIntermediateResult(exec::out_type& out, const FunctionState& state); - bool writeFinalResult(exec::out_type& out); + bool writeFinalResult(exec::out_type& out, const FunctionState& state); // Optional. Called during destruction. void destroy(HashStringAllocator* allocator); @@ -296,7 +329,8 @@ addInput This method adds raw input values to *this* accumulator. It receives a `HashStringAllocator*` followed by `exec::arg_type`-typed values, one for -each argument type `Ti` wrapped in InputType. +each argument type `Ti` wrapped in InputType. `const FunctionState&` hold the +function-level variables. With default-null behavior, raw-input rows where at least one column is null are ignored before `addInput` is called. After `addInput` is called, *this* @@ -306,31 +340,32 @@ combine """"""" This method adds an input intermediate state to *this* accumulator. It receives -a `HashStringAllocator*` and one `exec::arg_type` value. With -default-null behavior, nulls among the input intermediate states are ignored -before `combine` is called. After `combine` is called, *this* accumulator is -assumed to be non-null. +a `HashStringAllocator*` and one `exec::arg_type` value. +`const FunctionState&` hold the function-level variables. With default-null +behavior, nulls among the input intermediate states are ignored before `combine` +is called. After `combine` is called, *this* accumulator is assumed to be non-null. writeIntermediateResult """"""""""""""""""""""" This method writes *this* accumulator out to an intermediate state vector. It -has an out-parameter of the type `exec::out_type&`. This -method returns true if it writes a non-null value to `out`, or returns false -meaning a null should be written to the intermediate state vector. Accumulators -that are nulls (i.e., no value has been added to them) automatically become -nulls in the intermediate state vector without `writeIntermediateResult` being -called. +has an out-parameter of the type `exec::out_type&`. +`const FunctionState&` hold the function-level variables. This method returns +true if it writes a non-null value to `out`, or returns false meaning a null +should be written to the intermediate state vector. Accumulators that are +nulls (i.e., no value has been added to them) automatically become nulls in +the intermediate state vector without `writeIntermediateResult` being called. writeFinalResult """""""""""""""" This method writes *this* accumulator out to a final result vector. It -has an out-parameter of the type `exec::out_type&`. This -method returns true if it writes a non-null value to `out`, or returns false -meaning a null should be written to the final result vector. Accumulators -that are nulls (i.e., no value has been added to them) automatically become -nulls in the final result vector without `writeFinalResult` being called. +has an out-parameter of the type `exec::out_type&`. +`const FunctionState&` hold the function-level variables. This method returns +true if it writes a non-null value to `out`, or returns false meaning a null +should be written to the final result vector. Accumulators that are +nulls (i.e., no value has been added to them) automatically become nulls in the +final result vector without `writeFinalResult` being called. AccumulatorType of Non-Default-Null Behavior ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -355,15 +390,25 @@ For aggregaiton functions of non-default-null behavior, the author defines an explicit AccumulatorType(HashStringAllocator* allocator); - bool addInput(HashStringAllocator* allocator, exec::optional_arg_type value1, ...); + bool addInput( + HashStringAllocator* allocator, + exec::optional_arg_type value1, ..., + const FunctionState& state); bool combine( HashStringAllocator* allocator, - exec::optional_arg_type other); + exec::optional_arg_type other, + const FunctionState& state); - bool writeIntermediateResult(bool nonNullGroup, exec::out_type& out); + bool writeIntermediateResult( + bool nonNullGroup, + exec::out_type& out, + const FunctionState& state); - bool writeFinalResult(bool nonNullGroup, exec::out_type& out); + bool writeFinalResult( + bool nonNullGroup, + exec::out_type& out, + const FunctionState& state); // Optional. void destroy(HashStringAllocator* allocator); @@ -384,7 +429,7 @@ addInput This method receives a `HashStringAllocator*` followed by `exec::optional_arg_type` values, one for each argument type `Ti` wrapped -in InputType. +in InputType. `const FunctionState&` hold the function-level variables. This method is called on all raw-input rows even if some columns may be null. It returns a boolean meaning whether *this* accumulator is non-null after the @@ -397,26 +442,29 @@ combine """"""" This method receives a `HashStringAllocator*` and an -`exec::optional_arg_type` value. This method is called on -all intermediate states even if some are nulls. Same as `addInput`, this method -returns a boolean meaning whether *this* accumulator is non-null after the call. +`exec::optional_arg_type` value. `const FunctionState&` hold +the function-level variables.This method is called on all intermediate states +even if some are nulls. Same as `addInput`, this method returns a boolean +meaning whether *this* accumulator is non-null after the call. writeIntermediateResult """"""""""""""""""""""" This method has an out-parameter of the type `exec::out_type&` and a boolean flag `nonNullGroup` indicating whether *this* accumulator is -non-null. This method returns true if it writes a non-null value to `out`, or -return false meaning a null should be written to the intermediate state vector. +non-null. `const FunctionState&` hold the function-level variables. This method +returns true if it writes a non-null value to `out`, or return false meaning a +null should be written to the intermediate state vector. writeFinalResult """""""""""""""" This method writes *this* accumulator out to a final result vector. It has an out-parameter of the type `exec::out_type&` and a boolean flag -`nonNullGroup` indicating whether *this* accumulator is non-null. This method -returns true if it writes a non-null value to `out`, or return false meaning a -null should be written to the final result vector. +`nonNullGroup` indicating whether *this* accumulator is non-null. +`const FunctionState&` hold the function-level variables.This method returns +true if it writes a non-null value to `out`, or return false meaning a null +should be written to the final result vector. Limitations ^^^^^^^^^^^ diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index d6bc12aefcde..de3e6b529dac 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -129,6 +129,27 @@ class Aggregate { rowSizeOffset); } + // Initialize the function-level state of the simple function interface for + // UDAF. + // @param step The aggregation step. + // @param rawInputType The raw input type of the UDAF. + // @param resultType The result type of the current aggregation step. + // @param constantInputs Optional constant input values for aggregate + // function. constantInputs should be empty if there are no constant inputs, + // aligned with inputTypes if there is at least one constant input, with + // non-constant inputs represented as nullptr, and must be instances of + // ConstantVector. + // @param companionStep The step used to register aggregate companion + // functions. kPartial for partial companion function, kIntermediate for merge + // and merge extract companion function. + virtual void initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const TypePtr& resultType, + const std::vector& constantInputs, + std::optional companionStep = std::nullopt) { + } + // Initializes null flags and accumulators for newly encountered groups. This // function should be called only once for each group. // diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index c1a087f24d7c..5b4135428bc2 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -124,6 +124,20 @@ void AggregateCompanionFunctionBase::extractAccumulators( fn_->extractAccumulators(groups, numGroups, result); } +void AggregateCompanionAdapter::PartialFunction::initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const facebook::velox::TypePtr& resultType, + const std::vector& constantInputs, + std::optional /*companionStep*/) { + fn_->initialize( + step, + rawInputType, + resultType, + constantInputs, + core::AggregationNode::Step::kPartial); +} + void AggregateCompanionAdapter::PartialFunction::extractValues( char** groups, int32_t numGroups, @@ -131,6 +145,20 @@ void AggregateCompanionAdapter::PartialFunction::extractValues( fn_->extractAccumulators(groups, numGroups, result); } +void AggregateCompanionAdapter::MergeFunction::initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const facebook::velox::TypePtr& resultType, + const std::vector& constantInputs, + std::optional /*companionStep*/) { + fn_->initialize( + step, + rawInputType, + resultType, + constantInputs, + core::AggregationNode::Step::kIntermediate); +} + void AggregateCompanionAdapter::MergeFunction::addRawInput( char** groups, const SelectivityVector& rows, @@ -156,6 +184,20 @@ void AggregateCompanionAdapter::MergeFunction::extractValues( fn_->extractAccumulators(groups, numGroups, result); } +void AggregateCompanionAdapter::MergeExtractFunction::initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const facebook::velox::TypePtr& resultType, + const std::vector& constantInputs, + std::optional /*companionStep*/) { + fn_->initialize( + step, + rawInputType, + resultType, + constantInputs, + core::AggregationNode::Step::kFinal); +} + void AggregateCompanionAdapter::MergeExtractFunction::extractValues( char** groups, int32_t numGroups, @@ -229,6 +271,25 @@ void AggregateCompanionAdapter::ExtractFunction::apply( // Perform per-row aggregation. std::vector allSelectedRange; rows.applyToSelected([&](auto row) { allSelectedRange.push_back(row); }); + + // Get the raw input types. + std::vector rawInputTypes{args.size()}; + std::vector constantInputs{args.size()}; + for (auto i = 0; i < args.size(); i++) { + rawInputTypes[i] = args[i]->type(); + if (args[i]->isConstantEncoding()) { + constantInputs[i] = args[i]; + } else { + constantInputs[i] = nullptr; + } + } + + fn_->initialize( + core::AggregationNode::Step::kFinal, + rawInputTypes, + outputType, + constantInputs, + core::AggregationNode::Step::kFinal); fn_->initializeNewGroups(groups, allSelectedRange); fn_->enableValidateIntermediateInputs(); fn_->addIntermediateResults(groups, rows, args, false); diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 91b7c3a7bed8..6c316a74b888 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -99,6 +99,13 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : AggregateCompanionFunctionBase{std::move(fn), resultType} {} + void initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const TypePtr& resultType, + const std::vector& constantInputs, + std::optional companionStep) override; + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; @@ -110,6 +117,13 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : AggregateCompanionFunctionBase{std::move(fn), resultType} {} + void initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const TypePtr& resultType, + const std::vector& constantInputs, + std::optional companionStep) override; + void addRawInput( char** groups, const SelectivityVector& rows, @@ -133,6 +147,13 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : MergeFunction{std::move(fn), resultType} {} + void initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const TypePtr& resultType, + const std::vector& constantInputs, + std::optional companionStep) override; + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; diff --git a/velox/exec/AggregateInfo.cpp b/velox/exec/AggregateInfo.cpp index 9b19ab687965..0a51b2cfbfb1 100644 --- a/velox/exec/AggregateInfo.cpp +++ b/velox/exec/AggregateInfo.cpp @@ -103,6 +103,8 @@ std::vector toAggregateInfo( aggResultType, operatorCtx.driverCtx()->queryConfig()); + info.function->initialize( + step, aggregate.rawInputTypes, aggResultType, info.constantInputs); auto lambdas = extractLambdaInputs(aggregate); if (!lambdas.empty()) { if (expressionEvaluator == nullptr) { diff --git a/velox/exec/AggregateWindow.cpp b/velox/exec/AggregateWindow.cpp index cb32bd0779c3..83650f74322e 100644 --- a/velox/exec/AggregateWindow.cpp +++ b/velox/exec/AggregateWindow.cpp @@ -45,8 +45,10 @@ class AggregateWindowFunction : public exec::WindowFunction { argTypes_.reserve(args.size()); argIndices_.reserve(args.size()); argVectors_.reserve(args.size()); + constantInputs_.reserve(args.size()); for (const auto& arg : args) { argTypes_.push_back(arg.type); + constantInputs_.push_back(arg.constantValue); if (arg.constantValue) { argIndices_.push_back(kConstantChannel); argVectors_.push_back(arg.constantValue); @@ -151,6 +153,11 @@ class AggregateWindowFunction : public exec::WindowFunction { // aggregate_ function object should be initialized. auto singleGroup = std::vector{0}; aggregate_->clear(); + aggregate_->initialize( + core::AggregationNode::Step::kSingle, + argTypes_, + resultType_, + constantInputs_); aggregate_->initializeNewGroups(&rawSingleGroupRow_, singleGroup); aggregateInitialized_ = true; } @@ -332,6 +339,11 @@ class AggregateWindowFunction : public exec::WindowFunction { // the aggregation based on the frame changes with each row. This would // require adding new APIs to the Aggregate framework. aggregate_->clear(); + aggregate_->initialize( + core::AggregationNode::Step::kSingle, + argTypes_, + resultType_, + constantInputs_); aggregate_->initializeNewGroups(&rawSingleGroupRow_, kSingleGroup); aggregateInitialized_ = true; @@ -349,6 +361,11 @@ class AggregateWindowFunction : public exec::WindowFunction { // This value is returned for rows with empty frames. void computeDefaultAggregateValue(const TypePtr& resultType) { aggregate_->clear(); + aggregate_->initialize( + core::AggregationNode::Step::kSingle, + argTypes_, + resultType, + constantInputs_); aggregate_->initializeNewGroups( &rawSingleGroupRow_, std::vector{0}); aggregateInitialized_ = true; @@ -374,6 +391,11 @@ class AggregateWindowFunction : public exec::WindowFunction { std::vector argTypes_; std::vector argIndices_; std::vector argVectors_; + // Constant input values for aggregate function. it should be empty if there + // are no constant inputs, aligned with inputTypes if there is at least one + // constant input, with non-constant inputs represented as nullptr, and must + // be instances of ConstantVector. + std::vector constantInputs_; // This is a single aggregate row needed by the aggregate function for its // computation. These values are for the row and its various components. diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index d92584c6d647..0826016e36e8 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -44,6 +44,28 @@ class SimpleAggregateAdapter : public Aggregate { explicit SimpleAggregateAdapter(TypePtr resultType) : Aggregate(std::move(resultType)) {} + // Function-level states are variables hold by a UDAF instance that are + // computed once and used at every row when adding inputs to accumulators or + // extracting values from accumulators. + typename FUNC::FunctionState state_; + + void initialize( + core::AggregationNode::Step step, + const std::vector& rawInputTypes, + const TypePtr& resultType, + const std::vector& constantInputs, + std::optional companionStep) override { + if constexpr (support_initialize_function_state_) { + FUNC::initialize( + state_, + step, + rawInputTypes, + resultType, + constantInputs, + companionStep); + } + } + // Assume most aggregate functions have fixed-size accumulators. Functions // that // have non-fixed-size accumulators should overwrite `is_fixed_size_` in their @@ -145,6 +167,19 @@ class SimpleAggregateAdapter : public Aggregate { struct support_to_intermediate> : std::true_type {}; + // Whether the function defines its initialize() method or not. If it is + // defined, SimpleAggregateAdapter::supportInitializeFunctionState() returns + // true. Otherwise, SimpleAggregateAdapter::supportInitializeFunctionState() + // returns false and SimpleAggregateAdapter::initialize() will not initialize + // the UDAF's FunctionState. + template + struct support_initialize_function_state : std::false_type {}; + + template + struct support_initialize_function_state< + T, + std::void_t> : std::true_type {}; + // Whether the accumulator requires aligned access. If it is defined, // SimpleAggregateAdapter::accumulatorAlignmentSize() returns // alignof(typename FUNC::AccumulatorType). @@ -175,6 +210,9 @@ class SimpleAggregateAdapter : public Aggregate { static constexpr bool accumulator_is_aligned_ = accumulator_is_aligned::value; + static constexpr bool support_initialize_function_state_ = + support_initialize_function_state::value; + bool isFixedSize() const override { return accumulator_is_fixed_size_; } @@ -301,12 +339,13 @@ class SimpleAggregateAdapter : public Aggregate { if (isNull(groups[i])) { writer.commitNull(); } else { - bool nonNull = group->writeIntermediateResult(writer.current()); + bool nonNull = + group->writeIntermediateResult(writer.current(), state_); writer.commit(nonNull); } } else { bool nonNull = group->writeIntermediateResult( - !isNull(groups[i]), writer.current()); + !isNull(groups[i]), writer.current(), state_); writer.commit(nonNull); } } @@ -332,12 +371,12 @@ class SimpleAggregateAdapter : public Aggregate { if (isNull(groups[i])) { writer.commitNull(); } else { - bool nonNull = group->writeFinalResult(writer.current()); + bool nonNull = group->writeFinalResult(writer.current(), state_); writer.commit(nonNull); } } else { - bool nonNull = - group->writeFinalResult(!isNull(groups[i]), writer.current()); + bool nonNull = group->writeFinalResult( + !isNull(groups[i]), writer.current(), state_); writer.commit(nonNull); } } @@ -350,7 +389,8 @@ class SimpleAggregateAdapter : public Aggregate { folly::Range indices) override { setAllNulls(groups, indices); for (auto i : indices) { - new (groups[i] + offset_) typename FUNC::AccumulatorType(allocator_); + new (groups[i] + offset_) + typename FUNC::AccumulatorType(allocator_, state_); } } @@ -387,7 +427,7 @@ class SimpleAggregateAdapter : public Aggregate { tracker.emplace(groups[row][rowSizeOffset_], *allocator_); } auto group = value(groups[row]); - group->addInput(allocator_, std::get(readers)[row]...); + group->addInput(allocator_, std::get(readers)[row]..., state_); clearNull(groups[row]); }); } else { @@ -400,7 +440,8 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = group->addInput( allocator_, OptionalAccessor>{ - &std::get(readers), (int64_t)row}...); + &std::get(readers), (int64_t)row}..., + state_); if (nonNull) { clearNull(groups[row]); } @@ -427,7 +468,8 @@ class SimpleAggregateAdapter : public Aggregate { if constexpr (!accumulator_is_fixed_size_) { tracker.emplace(group[rowSizeOffset_], *allocator_); } - accumulator->addInput(allocator_, std::get(readers)[row]...); + accumulator->addInput( + allocator_, std::get(readers)[row]..., state_); clearNull(group); }); } else { @@ -439,7 +481,8 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = accumulator->addInput( allocator_, OptionalAccessor>{ - &std::get(readers), (int64_t)row}...); + &std::get(readers), (int64_t)row}..., + state_); if (nonNull) { clearNull(group); } @@ -511,7 +554,7 @@ class SimpleAggregateAdapter : public Aggregate { tracker.emplace(groups[row][rowSizeOffset_], *allocator_); } auto group = value(groups[row]); - group->combine(allocator_, reader[row]); + group->combine(allocator_, reader[row], state_); clearNull(groups[row]); }); } else { @@ -524,7 +567,8 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = group->combine( allocator_, OptionalAccessor{ - &reader, (int64_t)row}); + &reader, (int64_t)row}, + state_); if (nonNull) { clearNull(groups[row]); } @@ -549,7 +593,7 @@ class SimpleAggregateAdapter : public Aggregate { if constexpr (!accumulator_is_fixed_size_) { tracker.emplace(group[rowSizeOffset_], *allocator_); } - accumulator->combine(allocator_, reader[row]); + accumulator->combine(allocator_, reader[row], state_); clearNull(group); }); } else { @@ -561,7 +605,8 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = accumulator->combine( allocator_, OptionalAccessor{ - &reader, (int64_t)row}); + &reader, (int64_t)row}, + state_); if (nonNull) { clearNull(group); } diff --git a/velox/exec/tests/CMakeLists.txt b/velox/exec/tests/CMakeLists.txt index ddfb25743d23..3cc4f1aee144 100644 --- a/velox/exec/tests/CMakeLists.txt +++ b/velox/exec/tests/CMakeLists.txt @@ -244,8 +244,13 @@ add_executable(velox_simple_aggregate_test SimpleAggregateAdapterTest.cpp Main.cpp) target_link_libraries( - velox_simple_aggregate_test velox_simple_aggregate velox_exec - velox_functions_aggregates_test_lib gtest gtest_main) + velox_simple_aggregate_test + velox_simple_aggregate + velox_exec + velox_functions_aggregates_test_lib + velox_functions_window_test_lib + gtest + gtest_main) add_library(velox_spiller_join_benchmark_base JoinSpillInputBenchmarkBase.cpp SpillerBenchmarkBase.cpp) diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index 284814801739..18ea484db144 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -15,13 +15,16 @@ */ #include "velox/exec/SimpleAggregateAdapter.h" -#include "velox/exec/Aggregate.h" #include "velox/exec/tests/SimpleAggregateFunctionsRegistration.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" +#include "velox/functions/lib/window/tests/WindowTestBase.h" using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; +using facebook::velox::common::testutil::TestValue; using facebook::velox::functions::aggregate::test::AggregationTestBase; +using facebook::velox::window::test::WindowTestBase; namespace facebook::velox::aggregate::test { namespace { @@ -357,6 +360,8 @@ class CountNullsAggregate { using IntermediateType = int64_t; // Intermediate result type. using OutputType = int64_t; // Output vector type. + struct FunctionState {}; + static constexpr bool default_null_behavior_ = false; struct Accumulator { @@ -364,13 +369,16 @@ class CountNullsAggregate { Accumulator() = delete; - explicit Accumulator(HashStringAllocator* /*allocator*/) { + explicit Accumulator( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) { nullsCount_ = 0; } bool addInput( HashStringAllocator* /*allocator*/, - exec::optional_arg_type data) { + exec::optional_arg_type data, + const FunctionState& /*state*/) { if (!data.has_value()) { nullsCount_++; return true; @@ -380,7 +388,8 @@ class CountNullsAggregate { bool combine( HashStringAllocator* /*allocator*/, - exec::optional_arg_type nullsCount) { + exec::optional_arg_type nullsCount, + const FunctionState& /*state*/) { if (nullsCount.has_value()) { nullsCount_ += nullsCount.value(); return true; @@ -388,13 +397,17 @@ class CountNullsAggregate { return false; } - bool writeFinalResult(bool nonNull, exec::out_type& out) { + bool writeFinalResult( + bool nonNull, + exec::out_type& out, + const FunctionState& /*state*/) { return writeResult(nonNull, out); } bool writeIntermediateResult( bool nonNull, - exec::out_type& out) { + exec::out_type& out, + const FunctionState& /*state*/) { return writeResult(nonNull, out); } @@ -471,5 +484,235 @@ TEST_F(SimpleCountNullsAggregationTest, basic) { testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected}); } +// A testing aggregate function calculates a weighted average by taking an +// int64_t input value and a constant weight value as input. The sum of all +// input values is divided by the sum of all weight values to produce the final +// result. It is used to check for expectations in initialize method. +template +class FunctionStateTestAggregate { + public: + using InputType = Row; // Input vector type wrapped in Row. + using IntermediateType = Row; // Intermediate result type. + using OutputType = double; // Output vector type. + + struct FunctionState { + core::AggregationNode::Step step; + std::vector rawInputTypes; + TypePtr resultType; + std::vector constantInputs; + core::AggregationNode::Step companionStep; + }; + + static void checkConstantInputs( + const std::vector& constantInputs) { + // Check that the constantInputs is {nullptr, 1} + VELOX_CHECK_EQ(constantInputs.size(), 2); + VELOX_CHECK_NULL(constantInputs[0]); + VELOX_CHECK(constantInputs[1]->isConstantEncoding()); + VELOX_CHECK_EQ( + constantInputs[1] + ->template asUnchecked>() + ->valueAt(0), + 1); + } + + static void initialize( + FunctionState& state, + core::AggregationNode::Step step, + const std::vector& rawInputTypes, + const TypePtr& resultType, + const std::vector& constantInputs, + std::optional companionStep) { + std::vector expectedRawInputTypes = {BIGINT(), BIGINT()}; + auto expectedIntermediateType = ROW({BIGINT(), DOUBLE()}); + + if constexpr (testCompanion) { + // Check for companion functions. + VELOX_CHECK(companionStep.has_value()); + if (companionStep.value() == core::AggregationNode::Step::kPartial) { + VELOX_CHECK(rawInputTypes == expectedRawInputTypes); + if (step == core::AggregationNode::Step::kPartial || + step == core::AggregationNode::Step::kSingle) { + // Only check constant inputs in partial and single step. + checkConstantInputs(constantInputs); + } else { + VELOX_CHECK_EQ(constantInputs.size(), 1); + VELOX_CHECK_NULL(constantInputs[0]); + } + } else if ( + companionStep.value() == core::AggregationNode::Step::kIntermediate || + companionStep.value() == core::AggregationNode::Step::kFinal) { + VELOX_CHECK_EQ(rawInputTypes.size(), 1); + VELOX_CHECK(rawInputTypes[0]->equivalent(*expectedIntermediateType)); + + VELOX_CHECK_EQ(constantInputs.size(), 1); + VELOX_CHECK_NULL(constantInputs[0]); + } else { + VELOX_FAIL("Unexpected aggregation step"); + } + } else { + VELOX_CHECK(rawInputTypes == expectedRawInputTypes); + if (step == core::AggregationNode::Step::kPartial || + step == core::AggregationNode::Step::kSingle) { + // Only check constant inputs in partial and single step. + checkConstantInputs(constantInputs); + } else { + VELOX_CHECK_EQ(constantInputs.size(), 1); + VELOX_CHECK_NULL(constantInputs[0]); + } + } + + state.step = step; + state.rawInputTypes = rawInputTypes; + state.resultType = resultType; + state.constantInputs = constantInputs; + } + + struct Accumulator { + int64_t sum{0}; + double count{0}; + + explicit Accumulator( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} + + void addInput( + HashStringAllocator* /*allocator*/, + exec::arg_type data, + exec::arg_type /*increment*/, + const FunctionState& state) { + VELOX_CHECK_EQ(state.constantInputs.size(), 2); + VELOX_CHECK(state.constantInputs[1]->isConstantEncoding()); + auto constant = state.constantInputs[1] + ->template asUnchecked>() + ->valueAt(0); + sum += data; + count += constant; + } + + void combine( + HashStringAllocator* /*allocator*/, + exec::arg_type other, + const FunctionState& state) { + VELOX_CHECK(other.at<0>().has_value()); + VELOX_CHECK(other.at<1>().has_value()); + sum += other.at<0>().value(); + count += other.at<1>().value(); + } + + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& state) { + out = std::make_tuple(sum, count); + return true; + } + + bool writeFinalResult( + exec::out_type& out, + const FunctionState& state) { + out = sum / count; + return true; + } + }; + + using AccumulatorType = Accumulator; +}; + +template +exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( + const std::string& name) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .returnType("DOUBLE") + .intermediateType("ROW(BIGINT, DOUBLE)") + .argumentType("BIGINT") + .argumentType("BIGINT") + .build()}; + + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step /*step*/, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + VELOX_CHECK_LE(argTypes.size(), 2, "{} takes at most 2 argument", name); + return std::make_unique< + SimpleAggregateAdapter>>( + resultType); + }, + testCompanion, + true /*overwrite*/); +} + +void registerFunctionStateTestAggregate() { + registerSimpleFunctionStateTestAggregate( + "simple_function_state_agg_main"); + registerSimpleFunctionStateTestAggregate( + "simple_function_state_agg_companion"); +} + +class SimpleFunctionStateAggregationTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + + disableTestIncremental(); + disableTestStreaming(); + registerFunctionStateTestAggregate(); + } +}; + +TEST_F(SimpleFunctionStateAggregationTest, aggregate) { + auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); + std::vector finalResult = {2.5}; + auto expected = makeRowVector({makeFlatVector(finalResult)}); + testAggregations( + {inputVectors}, + {}, + {"simple_function_state_agg_main(c0, 1)"}, + {expected}); + testAggregationsWithCompanion( + {inputVectors}, + [](auto& /*builder*/) {}, + {}, + {"simple_function_state_agg_companion(c0, 1)"}, + {{BIGINT(), BIGINT()}}, + {}, + {expected}, + {}); +} + +class SimpleFunctionStateWindowTest : public WindowTestBase { + protected: + void SetUp() override { + WindowTestBase::SetUp(); + + registerFunctionStateTestAggregate(); + } +}; + +TEST_F(SimpleFunctionStateWindowTest, window) { + auto inputVectors = makeRowVector({ + makeFlatVector({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5}), + makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), + }); + + auto expected = makeRowVector({ + inputVectors->childAt(0), + inputVectors->childAt(1), + makeFlatVector( + {2.0, 2.0, 2.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 10.5, 10.5, 12.0}), + }); + WindowTestBase::testWindowFunction( + {inputVectors}, + "simple_function_state_agg_main(c1, 1)", + {"partition by c0"}, + {}, + expected); +} + } // namespace } // namespace facebook::velox::aggregate::test diff --git a/velox/exec/tests/SimpleArrayAggAggregate.cpp b/velox/exec/tests/SimpleArrayAggAggregate.cpp index e3998aadf2a2..591865845b82 100644 --- a/velox/exec/tests/SimpleArrayAggAggregate.cpp +++ b/velox/exec/tests/SimpleArrayAggAggregate.cpp @@ -36,6 +36,8 @@ class ArrayAggAggregate { // Type of output vector. using OutputType = Array>; + struct FunctionState {}; + static constexpr bool default_null_behavior_ = false; static bool toIntermediate( @@ -55,7 +57,9 @@ class ArrayAggAggregate { AccumulatorType() = delete; // Constructor used in initializeNewGroups(). - explicit AccumulatorType(HashStringAllocator* /*allocator*/) + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) : elements_{} {} static constexpr bool is_fixed_size_ = false; @@ -64,7 +68,8 @@ class ArrayAggAggregate { // child-type T wrapped in InputType. bool addInput( HashStringAllocator* allocator, - exec::optional_arg_type> data) { + exec::optional_arg_type> data, + const FunctionState& /*state*/) { elements_.appendValue(data, allocator); return true; } @@ -73,7 +78,8 @@ class ArrayAggAggregate { // exec::optional_arg_type. bool combine( HashStringAllocator* allocator, - exec::optional_arg_type>> other) { + exec::optional_arg_type>> other, + const FunctionState& /*state*/) { if (!other.has_value()) { return false; } @@ -85,7 +91,8 @@ class ArrayAggAggregate { bool writeFinalResult( bool nonNullGroup, - exec::out_type>>& out) { + exec::out_type>>& out, + const FunctionState& /*state*/) { if (!nonNullGroup) { return false; } @@ -95,7 +102,8 @@ class ArrayAggAggregate { bool writeIntermediateResult( bool nonNullGroup, - exec::out_type>>& out) { + exec::out_type>>& out, + const FunctionState& /*state*/) { // If the group's accumulator is null, the corresponding intermediate // result is null too. if (!nonNullGroup) { diff --git a/velox/exec/tests/SimpleAverageAggregate.cpp b/velox/exec/tests/SimpleAverageAggregate.cpp index 9f887f34eadf..9f9494648147 100644 --- a/velox/exec/tests/SimpleAverageAggregate.cpp +++ b/velox/exec/tests/SimpleAverageAggregate.cpp @@ -42,6 +42,8 @@ class AverageAggregate { using OutputType = std::conditional_t, float, double>; + struct FunctionState {}; + static bool toIntermediate( exec::out_type>& out, exec::arg_type in) { @@ -56,14 +58,19 @@ class AverageAggregate { AccumulatorType() = delete; // Constructor used in initializeNewGroups(). - explicit AccumulatorType(HashStringAllocator* /*allocator*/) { + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) { sum_ = 0; count_ = 0; } // addInput expects one parameter of exec::arg_type for each child-type T // wrapped in InputType. - void addInput(HashStringAllocator* /*allocator*/, exec::arg_type data) { + void addInput( + HashStringAllocator* /*allocator*/, + exec::arg_type data, + const FunctionState& /*state*/) { sum_ += data; count_ = checkedPlus(count_, 1); } @@ -71,7 +78,8 @@ class AverageAggregate { // combine expects one parameter of exec::arg_type. void combine( HashStringAllocator* /*allocator*/, - exec::arg_type> other) { + exec::arg_type> other, + const FunctionState& /*state*/) { // Both field of an intermediate result should be non-null because // writeIntermediateResult() never make an intermediate result with a // single null. @@ -81,12 +89,16 @@ class AverageAggregate { count_ = checkedPlus(count_, other.at<1>().value()); } - bool writeFinalResult(exec::out_type& out) { + bool writeFinalResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = sum_ / count_; return true; } - bool writeIntermediateResult(exec::out_type& out) { + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = std::make_tuple(sum_, count_); return true; } diff --git a/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp b/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp index aa3a52d5fa28..8a8034ec3f8f 100644 --- a/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp +++ b/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp @@ -32,6 +32,8 @@ class BitwiseXorAggregate { using OutputType = T; + struct FunctionState {}; + static bool toIntermediate(exec::out_type& out, exec::arg_type in) { out = in; return true; @@ -42,22 +44,34 @@ class BitwiseXorAggregate { AccumulatorType() = delete; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} - void addInput(HashStringAllocator* /*allocator*/, exec::arg_type data) { + void addInput( + HashStringAllocator* /*allocator*/, + exec::arg_type data, + const FunctionState& /*state*/) { xor_ ^= data; } - void combine(HashStringAllocator* /*allocator*/, exec::arg_type other) { + void combine( + HashStringAllocator* /*allocator*/, + exec::arg_type other, + const FunctionState& /*state*/) { xor_ ^= other; } - bool writeFinalResult(exec::out_type& out) { + bool writeFinalResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = xor_; return true; } - bool writeIntermediateResult(exec::out_type& out) { + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = xor_; return true; } diff --git a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp index 9bbe9ec8c73b..703a2c5c3367 100644 --- a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp +++ b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp @@ -35,6 +35,8 @@ class GeometricMeanAggregate { using OutputType = TResult; + struct FunctionState {}; + static bool toIntermediate( exec::out_type>& out, exec::arg_type in) { @@ -50,30 +52,38 @@ class GeometricMeanAggregate { AccumulatorType() = delete; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} void addInput( HashStringAllocator* /*allocator*/, - exec::arg_type data) { + exec::arg_type data, + const FunctionState& /*state*/) { logSum_ += std::log(data); count_ = checkedPlus(count_, 1); } void combine( HashStringAllocator* /*allocator*/, - exec::arg_type> other) { + exec::arg_type> other, + const FunctionState& /*state*/) { VELOX_CHECK(other.at<0>().has_value()); VELOX_CHECK(other.at<1>().has_value()); logSum_ += other.at<0>().value(); count_ = checkedPlus(count_, other.at<1>().value()); } - bool writeFinalResult(exec::out_type& out) { + bool writeFinalResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = std::exp(logSum_ / count_); return true; } - bool writeIntermediateResult(exec::out_type& out) { + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = std::make_tuple(logSum_, count_); return true; } diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp index e2c14cfa7969..3a05e4cab1e7 100644 --- a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp +++ b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp @@ -32,6 +32,8 @@ class CollectListAggregate { using OutputType = Array>; + struct FunctionState {}; + /// In Spark, when all inputs are null, the output is an empty array instead /// of null. Therefore, in the writeIntermediateResult and writeFinalResult, /// we still need to output the empty element_ when the group is null. This @@ -51,14 +53,17 @@ class CollectListAggregate { struct AccumulatorType { ValueList elements_; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) : elements_{} {} static constexpr bool is_fixed_size_ = false; bool addInput( HashStringAllocator* allocator, - exec::optional_arg_type> data) { + exec::optional_arg_type> data, + const FunctionState& /*state*/) { if (data.has_value()) { elements_.appendValue(data, allocator); return true; @@ -68,7 +73,8 @@ class CollectListAggregate { bool combine( HashStringAllocator* allocator, - exec::optional_arg_type other) { + exec::optional_arg_type other, + const FunctionState& /*state*/) { if (!other.has_value()) { return false; } @@ -80,7 +86,8 @@ class CollectListAggregate { bool writeIntermediateResult( bool /*nonNullGroup*/, - exec::out_type& out) { + exec::out_type& out, + const FunctionState& /*state*/) { // If the group's accumulator is null, the corresponding intermediate // result is an empty array. copyValueListToArrayWriter(out, elements_); @@ -89,7 +96,8 @@ class CollectListAggregate { bool writeFinalResult( bool /*nonNullGroup*/, - exec::out_type& out) { + exec::out_type& out, + const FunctionState& /*state*/) { // If the group's accumulator is null, the corresponding result is an // empty array. copyValueListToArrayWriter(out, elements_); diff --git a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h index 5019ae1a2ca1..ce25c12af763 100644 --- a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -34,6 +34,8 @@ class DecimalSumAggregate { using OutputType = TSumType; + struct FunctionState {}; + /// Spark's decimal sum doesn't have the concept of a null group, each group /// is initialized with an initial value, where sum = 0 and isEmpty = true. /// The final agg may fallback to being executed in Spark, so the meaning of @@ -70,7 +72,9 @@ class DecimalSumAggregate { AccumulatorType() = delete; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} std::optional computeFinalResult() const { if (!sum.has_value()) { @@ -92,7 +96,8 @@ class DecimalSumAggregate { bool addInput( HashStringAllocator* /*allocator*/, - exec::optional_arg_type data) { + exec::optional_arg_type data, + const FunctionState& /*state*/) { if (!data.has_value()) { return false; } @@ -112,7 +117,8 @@ class DecimalSumAggregate { bool combine( HashStringAllocator* /*allocator*/, - exec::optional_arg_type> other) { + exec::optional_arg_type> other, + const FunctionState& /*state*/) { if (!other.has_value()) { return false; } @@ -143,7 +149,8 @@ class DecimalSumAggregate { bool writeIntermediateResult( bool nonNullGroup, - exec::out_type& out) { + exec::out_type& out, + const FunctionState& /*state*/) { if (!nonNullGroup) { // If a group is null, all values in this group are null. In Spark, this // group will be the initial value, where sum is 0 and isEmpty is true. @@ -163,7 +170,10 @@ class DecimalSumAggregate { return true; } - bool writeFinalResult(bool nonNullGroup, exec::out_type& out) { + bool writeFinalResult( + bool nonNullGroup, + exec::out_type& out, + const FunctionState& /*state*/) { if (!nonNullGroup || isEmpty) { // If isEmpty is true, we should set null. return false; diff --git a/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp b/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp index 85f69d1f7c3f..5b5e63b9bea8 100644 --- a/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp +++ b/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp @@ -29,6 +29,8 @@ class RegrReplacementAggregate { /*m2*/ double>; using OutputType = double; + struct FunctionState {}; + static bool toIntermediate( exec::out_type>& out, exec::arg_type in) { @@ -41,11 +43,14 @@ class RegrReplacementAggregate { double avg{0.0}; double m2{0.0}; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} void addInput( HashStringAllocator* /*allocator*/, - exec::arg_type data) { + exec::arg_type data, + const FunctionState& /*state*/) { n += 1.0; double delta = data - avg; double deltaN = delta / n; @@ -55,7 +60,8 @@ class RegrReplacementAggregate { void combine( HashStringAllocator* /*allocator*/, - exec::arg_type> other) { + exec::arg_type> other, + const FunctionState& /*state*/) { VELOX_CHECK(other.at<0>().has_value()); VELOX_CHECK(other.at<1>().has_value()); VELOX_CHECK(other.at<2>().has_value()); @@ -72,12 +78,16 @@ class RegrReplacementAggregate { m2 += otherM2 + delta * deltaN * originN * otherN; } - bool writeIntermediateResult(exec::out_type& out) { + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = std::make_tuple(n, avg, m2); return true; } - bool writeFinalResult(exec::out_type& out) { + bool writeFinalResult( + exec::out_type& out, + const FunctionState& /*state*/) { if (n == 0.0) { return false; }