From 64e6cc8ae6a550c9df10611de7df36008422ef31 Mon Sep 17 00:00:00 2001 From: Pramod Date: Mon, 25 Mar 2024 21:37:29 -0700 Subject: [PATCH] Add isCompanionFunction to function metadata --- velox/exec/Aggregate.cpp | 21 +++++++-- velox/exec/Aggregate.h | 6 +++ velox/exec/AggregateCompanionAdapter.cpp | 25 ++++++++--- velox/exec/AggregateCompanionAdapter.h | 4 ++ velox/exec/AggregateWindow.cpp | 8 +++- velox/exec/AggregateWindow.h | 5 ++- velox/exec/WindowFunction.h | 7 +-- .../exec/tests/WindowFunctionRegistryTest.cpp | 19 ++++++++ velox/expression/FunctionMetadata.h | 8 ++++ .../lib/aggregates/BitwiseAggregateBase.h | 2 +- velox/functions/lib/window/Rank.cpp | 2 +- velox/functions/lib/window/RowNumber.cpp | 2 +- .../prestosql/aggregates/AverageAggregate.cpp | 2 +- .../prestosql/aggregates/BoolAggregates.cpp | 2 +- .../aggregates/ChecksumAggregate.cpp | 2 +- .../prestosql/aggregates/CountAggregate.cpp | 2 +- .../prestosql/aggregates/CountIfAggregate.cpp | 2 +- .../aggregates/GeometricMeanAggregate.cpp | 2 +- .../aggregates/HistogramAggregate.cpp | 2 +- .../prestosql/aggregates/MinMaxAggregates.cpp | 2 +- .../prestosql/aggregates/ReduceAgg.cpp | 2 +- .../prestosql/aggregates/SumAggregate.cpp | 2 +- .../tests/AggregationFunctionRegTest.cpp | 43 +++++++++++-------- .../functions/tests/FunctionRegistryTest.cpp | 18 ++++++++ 24 files changed, 145 insertions(+), 45 deletions(-) diff --git a/velox/exec/Aggregate.cpp b/velox/exec/Aggregate.cpp index 1bab06a1a833..e5697fe7b660 100644 --- a/velox/exec/Aggregate.cpp +++ b/velox/exec/Aggregate.cpp @@ -80,22 +80,25 @@ AggregateRegistrationResult registerAggregateFunction( } // Register the aggregate as a window function also. - registerAggregateWindowFunction(sanitizedName); + registerAggregateWindowFunction(sanitizedName, metadata); // Register companion function if needed. if (registerCompanionFunctions) { + auto companionMetadata = metadata; + companionMetadata.isCompanionFunction = true; + registered.partialFunction = CompanionFunctionsRegistrar::registerPartialFunction( - name, signatures, overwrite); + name, signatures, companionMetadata, overwrite); registered.mergeFunction = CompanionFunctionsRegistrar::registerMergeFunction( - name, signatures, overwrite); + name, signatures, companionMetadata, overwrite); registered.extractFunction = CompanionFunctionsRegistrar::registerExtractFunction( name, signatures, overwrite); registered.mergeExtractFunction = CompanionFunctionsRegistrar::registerMergeExtractFunction( - name, signatures, overwrite); + name, signatures, companionMetadata, overwrite); } return registered; } @@ -141,6 +144,16 @@ std::vector registerAggregateFunction( return registrationResults; } +const AggregateFunctionMetadata& getAggregateFunctionMetadata( + const std::string& name) { + const auto sanitizedName = sanitizeName(name); + if (auto func = getAggregateFunctionEntry(sanitizedName)) { + return func->metadata; + } else { + VELOX_USER_FAIL("Metadata not found for aggregate function: {}", name); + } +} + std::unordered_map< std::string, std::vector>> diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index d6bc12aefcde..36eac5063430 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -476,6 +476,9 @@ struct AggregateFunctionMetadata { /// True if results of the aggregation depend on the order of inputs. For /// example, array_agg is order sensitive while count is not. bool orderSensitive{true}; + + /// Indicates if this is a companion function. + bool isCompanionFunction{false}; }; /// Register an aggregate function with the specified name and signatures. If /// registerCompanionFunctions is true, also register companion aggregate and @@ -514,6 +517,9 @@ std::vector registerAggregateFunction( bool registerCompanionFunctions, bool overwrite); +const AggregateFunctionMetadata& getAggregateFunctionMetadata( + const std::string& name); + /// Returns signatures of the aggregate function with the specified name. /// Returns empty std::optional if function with that name is not found. std::optional>> diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index d0c36e660224..54145dd311ba 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -249,6 +249,7 @@ void AggregateCompanionAdapter::ExtractFunction::apply( bool CompanionFunctionsRegistrar::registerPartialFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { auto partialSignatures = CompanionSignatures::partialFunctionSignatures(signatures); @@ -280,6 +281,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction( name, CompanionSignatures::partialFunctionName(name)); }, + metadata, /*registerCompanionFunctions*/ false, overwrite) .mainFunction; @@ -288,6 +290,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction( bool CompanionFunctionsRegistrar::registerMergeFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { auto mergeSignatures = CompanionSignatures::mergeFunctionSignatures(signatures); @@ -320,16 +323,18 @@ bool CompanionFunctionsRegistrar::registerMergeFunction( name, CompanionSignatures::mergeFunctionName(name)); }, + metadata, /*registerCompanionFunctions*/ false, overwrite) .mainFunction; } -bool registerAggregateFunction( +bool registerMergeExtractFunctionInternal( const std::string& name, const std::string& mergeExtractFunctionName, const std::vector>& mergeExtractSignatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { return exec::registerAggregateFunction( mergeExtractFunctionName, @@ -365,6 +370,7 @@ bool registerAggregateFunction( name, mergeExtractFunctionName); }, + metadata, /*registerCompanionFunctions*/ false, overwrite) .mainFunction; @@ -373,6 +379,7 @@ bool registerAggregateFunction( bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { auto groupedSignatures = CompanionSignatures::groupSignaturesByReturnType(signatures); @@ -387,10 +394,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix( auto mergeExtractFunctionName = CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type); - registered |= registerAggregateFunction( + registered |= registerMergeExtractFunctionInternal( name, mergeExtractFunctionName, std::move(mergeExtractSignatures), + metadata, overwrite); } return registered; @@ -399,10 +407,12 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix( bool CompanionFunctionsRegistrar::registerMergeExtractFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite) { if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures( signatures)) { - return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite); + return registerMergeExtractFunctionWithSuffix( + name, signatures, metadata, overwrite); } auto mergeExtractSignatures = @@ -413,10 +423,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction( auto mergeExtractFunctionName = CompanionSignatures::mergeExtractFunctionName(name); - return registerAggregateFunction( + return registerMergeExtractFunctionInternal( name, mergeExtractFunctionName, std::move(mergeExtractSignatures), + metadata, overwrite); } @@ -475,6 +486,7 @@ bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix( std::move(factory), exec::VectorFunctionMetadataBuilder() .defaultNullBehavior(false) + .isCompanionFunction(true) .build(), overwrite); } @@ -502,7 +514,10 @@ bool CompanionFunctionsRegistrar::registerExtractFunction( CompanionSignatures::extractFunctionName(originalName), std::move(extractSignatures), std::move(factory), - exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(), + exec::VectorFunctionMetadataBuilder() + .defaultNullBehavior(false) + .isCompanionFunction(true) + .build(), overwrite); } diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 91b7c3a7bed8..29c5ef246b21 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -178,6 +178,7 @@ class CompanionFunctionsRegistrar { static bool registerPartialFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite = false); // When there is already a function of the same name as the merge companion @@ -186,6 +187,7 @@ class CompanionFunctionsRegistrar { static bool registerMergeFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite = false); // If there are multiple signatures of the original aggregation function @@ -213,6 +215,7 @@ class CompanionFunctionsRegistrar { static bool registerMergeExtractFunction( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite = false); private: @@ -227,6 +230,7 @@ class CompanionFunctionsRegistrar { static bool registerMergeExtractFunctionWithSuffix( const std::string& name, const std::vector& signatures, + const AggregateFunctionMetadata& metadata, bool overwrite); }; diff --git a/velox/exec/AggregateWindow.cpp b/velox/exec/AggregateWindow.cpp index 2bdad5342c6e..f863b2bbb63f 100644 --- a/velox/exec/AggregateWindow.cpp +++ b/velox/exec/AggregateWindow.cpp @@ -397,7 +397,9 @@ class AggregateWindowFunction : public exec::WindowFunction { } // namespace -void registerAggregateWindowFunction(const std::string& name) { +void registerAggregateWindowFunction( + const std::string& name, + const AggregateFunctionMetadata& metadata) { auto aggregateFunctionSignatures = exec::getAggregateFunctionSignatures(name); if (aggregateFunctionSignatures.has_value()) { // This copy is needed to obtain a vector of the base FunctionSignaturePtr @@ -410,7 +412,9 @@ void registerAggregateWindowFunction(const std::string& name) { exec::registerWindowFunction( name, std::move(signatures), - {exec::WindowFunction::ProcessMode::kRows, true}, + {exec::WindowFunction::ProcessMode::kRows, + true, + metadata.isCompanionFunction}, [name]( const std::vector& args, const TypePtr& resultType, diff --git a/velox/exec/AggregateWindow.h b/velox/exec/AggregateWindow.h index 09bb55f06199..03845e9796dd 100644 --- a/velox/exec/AggregateWindow.h +++ b/velox/exec/AggregateWindow.h @@ -15,9 +15,12 @@ */ #pragma once #include +#include "velox/exec/Aggregate.h" namespace facebook::velox::exec { -void registerAggregateWindowFunction(const std::string& name); +void registerAggregateWindowFunction( + const std::string& name, + const AggregateFunctionMetadata& metadata); } // namespace facebook::velox::exec diff --git a/velox/exec/WindowFunction.h b/velox/exec/WindowFunction.h index e9bea92ee2c2..71a49502a09f 100644 --- a/velox/exec/WindowFunction.h +++ b/velox/exec/WindowFunction.h @@ -43,14 +43,15 @@ class WindowFunction { kRows, }; - /// Indicates whether this is an aggregate window function and its process - /// unit. + /// Indicates whether this is an aggregate window function, whether it is a + /// companion function, and its process unit. struct Metadata { ProcessMode processMode; bool isAggregate; + bool isCompanionFunction; static Metadata defaultMetadata() { - static Metadata defaultValue{ProcessMode::kPartition, false}; + static Metadata defaultValue{ProcessMode::kPartition, false, false}; return defaultValue; } }; diff --git a/velox/exec/tests/WindowFunctionRegistryTest.cpp b/velox/exec/tests/WindowFunctionRegistryTest.cpp index f78b18a05024..5fe5ef02b32c 100644 --- a/velox/exec/tests/WindowFunctionRegistryTest.cpp +++ b/velox/exec/tests/WindowFunctionRegistryTest.cpp @@ -17,6 +17,7 @@ #include "velox/exec/WindowFunction.h" #include "velox/expression/SignatureBinder.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" namespace facebook::velox::exec::test { @@ -134,4 +135,22 @@ TEST_F(WindowFunctionRegistryTest, prefix) { } } +TEST_F(WindowFunctionRegistryTest, isCompanionFunction) { + aggregate::prestosql::registerAllAggregateFunctions(); + window::prestosql::registerAllWindowFunctions(); + const auto windowFunctions = { + "count", "lead", "ntile", "nth_value", "first_value", "map_union_sum"}; + const auto companionFunctions = { + "approx_most_frequent_partial", + "approx_percentile_merge", + "arbitrary_merge_extract"}; + + for (const auto& function : windowFunctions) { + ASSERT_FALSE(getWindowFunctionMetadata(function).isCompanionFunction); + } + for (const auto& function : companionFunctions) { + ASSERT_TRUE(getWindowFunctionMetadata(function).isCompanionFunction); + } +} + } // namespace facebook::velox::exec::test diff --git a/velox/expression/FunctionMetadata.h b/velox/expression/FunctionMetadata.h index bc354110b955..706b306f7962 100644 --- a/velox/expression/FunctionMetadata.h +++ b/velox/expression/FunctionMetadata.h @@ -40,6 +40,9 @@ struct VectorFunctionMetadata { /// In this case, 'rows' in VectorFunction::apply will point only to positions /// for which all arguments are not null. bool defaultNullBehavior{true}; + + /// Indicates if this is a companion function. + bool isCompanionFunction{false}; }; class VectorFunctionMetadataBuilder { @@ -59,6 +62,11 @@ class VectorFunctionMetadataBuilder { return *this; } + VectorFunctionMetadataBuilder& isCompanionFunction(bool isCompanionFunction) { + metadata_.isCompanionFunction = isCompanionFunction; + return *this; + } + const VectorFunctionMetadata& build() const { return metadata_; } diff --git a/velox/functions/lib/aggregates/BitwiseAggregateBase.h b/velox/functions/lib/aggregates/BitwiseAggregateBase.h index cb9405d2c3d6..bde1179f14ab 100644 --- a/velox/functions/lib/aggregates/BitwiseAggregateBase.h +++ b/velox/functions/lib/aggregates/BitwiseAggregateBase.h @@ -114,7 +114,7 @@ exec::AggregateRegistrationResult registerBitwise( inputType->kindName()); } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/lib/window/Rank.cpp b/velox/functions/lib/window/Rank.cpp index 625d2bd3b977..143aae3008c3 100644 --- a/velox/functions/lib/window/Rank.cpp +++ b/velox/functions/lib/window/Rank.cpp @@ -113,7 +113,7 @@ void registerRankInternal( exec::registerWindowFunction( name, std::move(signatures), - {exec::WindowFunction::ProcessMode::kRows, false}, + {exec::WindowFunction::ProcessMode::kRows, false, false}, std::move(windowFunctionFactory)); } else { exec::registerWindowFunction( diff --git a/velox/functions/lib/window/RowNumber.cpp b/velox/functions/lib/window/RowNumber.cpp index 81fc7e5e0c6d..5c524751d7f6 100644 --- a/velox/functions/lib/window/RowNumber.cpp +++ b/velox/functions/lib/window/RowNumber.cpp @@ -75,7 +75,7 @@ void registerRowNumber(const std::string& name, TypeKind resultTypeKind) { exec::registerWindowFunction( name, std::move(signatures), - {exec::WindowFunction::ProcessMode::kRows, false}, + {exec::WindowFunction::ProcessMode::kRows, false, false}, [name]( const std::vector& /*args*/, const TypePtr& resultType, diff --git a/velox/functions/prestosql/aggregates/AverageAggregate.cpp b/velox/functions/prestosql/aggregates/AverageAggregate.cpp index d54c306cceac..6c9d18f3de2c 100644 --- a/velox/functions/prestosql/aggregates/AverageAggregate.cpp +++ b/velox/functions/prestosql/aggregates/AverageAggregate.cpp @@ -142,7 +142,7 @@ void registerAverageAggregate( } } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/BoolAggregates.cpp b/velox/functions/prestosql/aggregates/BoolAggregates.cpp index 5095cb373cca..e7c594d9e576 100644 --- a/velox/functions/prestosql/aggregates/BoolAggregates.cpp +++ b/velox/functions/prestosql/aggregates/BoolAggregates.cpp @@ -209,7 +209,7 @@ exec::AggregateRegistrationResult registerBool( inputType->kindName()); return std::make_unique(); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/ChecksumAggregate.cpp b/velox/functions/prestosql/aggregates/ChecksumAggregate.cpp index 458be12c497c..93acaab0c94d 100644 --- a/velox/functions/prestosql/aggregates/ChecksumAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ChecksumAggregate.cpp @@ -262,7 +262,7 @@ void registerChecksumAggregate( return std::make_unique(VARBINARY()); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/CountAggregate.cpp b/velox/functions/prestosql/aggregates/CountAggregate.cpp index 8b484ef4c8fa..4c539e4cf207 100644 --- a/velox/functions/prestosql/aggregates/CountAggregate.cpp +++ b/velox/functions/prestosql/aggregates/CountAggregate.cpp @@ -182,7 +182,7 @@ void registerCountAggregate( argTypes.size(), 1, "{} takes at most one argument", name); return std::make_unique(); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/CountIfAggregate.cpp b/velox/functions/prestosql/aggregates/CountIfAggregate.cpp index 5da65ec7ce4a..f844aec97735 100644 --- a/velox/functions/prestosql/aggregates/CountIfAggregate.cpp +++ b/velox/functions/prestosql/aggregates/CountIfAggregate.cpp @@ -206,7 +206,7 @@ void registerCountIfAggregate( return std::make_unique(); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp index 9bbe9ec8c73b..394db46eb469 100644 --- a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp +++ b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp @@ -136,7 +136,7 @@ void registerGeometricMeanAggregate( inputType->toString()); } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/HistogramAggregate.cpp b/velox/functions/prestosql/aggregates/HistogramAggregate.cpp index c2481925b42f..fcd0e5909699 100644 --- a/velox/functions/prestosql/aggregates/HistogramAggregate.cpp +++ b/velox/functions/prestosql/aggregates/HistogramAggregate.cpp @@ -632,7 +632,7 @@ void registerHistogramAggregate( inputType->toString()); } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp index 871a62733d33..046f717fa8e1 100644 --- a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -514,7 +514,7 @@ exec::AggregateRegistrationResult registerMinMax( } } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/ReduceAgg.cpp b/velox/functions/prestosql/aggregates/ReduceAgg.cpp index 7c7553ffd3e7..991bb40d48cb 100644 --- a/velox/functions/prestosql/aggregates/ReduceAgg.cpp +++ b/velox/functions/prestosql/aggregates/ReduceAgg.cpp @@ -817,7 +817,7 @@ void registerReduceAgg( const core::QueryConfig& config) -> std::unique_ptr { return std::make_unique(resultType); }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/SumAggregate.cpp b/velox/functions/prestosql/aggregates/SumAggregate.cpp index 0f6278e96a75..ab6c20ab9432 100644 --- a/velox/functions/prestosql/aggregates/SumAggregate.cpp +++ b/velox/functions/prestosql/aggregates/SumAggregate.cpp @@ -110,7 +110,7 @@ exec::AggregateRegistrationResult registerSum( inputType->kindName()); } }, - {false /*orderSensitive*/}, + {false /*orderSensitive*/, false /*isCompanionFunction*/}, withCompanionFunctions, overwrite); } diff --git a/velox/functions/prestosql/aggregates/tests/AggregationFunctionRegTest.cpp b/velox/functions/prestosql/aggregates/tests/AggregationFunctionRegTest.cpp index 40777e874efb..4d735c105e49 100644 --- a/velox/functions/prestosql/aggregates/tests/AggregationFunctionRegTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/AggregationFunctionRegTest.cpp @@ -58,6 +58,7 @@ TEST_F(AggregationFunctionRegTest, orderSensitive) { // Remove all functions and check for no entries. clearAndCheckRegistry(); + aggregate::prestosql::registerAllAggregateFunctions(); std::set nonOrderSensitiveFunctions = { "sum", "avg", @@ -74,27 +75,16 @@ TEST_F(AggregationFunctionRegTest, orderSensitive) { "geometric_mean", "histogram", "reduce_agg"}; - aggregate::prestosql::registerAllAggregateFunctions(); - exec::aggregateFunctions().withRLock([&](const auto& aggrFuncMap) { - for (const auto& entry : aggrFuncMap) { - if (!entry.second.metadata.orderSensitive) { - EXPECT_EQ(1, nonOrderSensitiveFunctions.erase(entry.first)); - } - } - }); - EXPECT_EQ(0, nonOrderSensitiveFunctions.size()); + for (const auto& entry : nonOrderSensitiveFunctions) { + ASSERT_FALSE(exec::getAggregateFunctionMetadata(entry).orderSensitive); + } // Test some but not all order sensitive functions std::set orderSensitiveFunctions = { "array_agg", "arbitrary", "any_value", "map_agg", "map_union", "set_agg"}; - exec::aggregateFunctions().withRLock([&](const auto& aggrFuncMap) { - for (const auto& entry : aggrFuncMap) { - if (entry.second.metadata.orderSensitive) { - orderSensitiveFunctions.erase(entry.first); - } - } - }); - EXPECT_EQ(0, orderSensitiveFunctions.size()); + for (const auto& entry : orderSensitiveFunctions) { + ASSERT_TRUE(exec::getAggregateFunctionMetadata(entry).orderSensitive); + } } TEST_F(AggregationFunctionRegTest, prestoSupportedSignatures) { @@ -121,4 +111,23 @@ TEST_F(AggregationFunctionRegTest, prestoSupportedSignatures) { clearAndCheckRegistry(); } +TEST_F(AggregationFunctionRegTest, isCompanionFunction) { + // Remove all functions and check for no entries. + clearAndCheckRegistry(); + + aggregate::prestosql::registerAllAggregateFunctions(); + const auto aggregates = {"approx_distinct", "count", "sum"}; + const auto companionFunctions = { + "approx_distinct_merge", "approx_distinct_partial"}; + + for (const auto& function : aggregates) { + ASSERT_FALSE( + exec::getAggregateFunctionMetadata(function).isCompanionFunction); + } + for (const auto& function : companionFunctions) { + ASSERT_TRUE( + exec::getAggregateFunctionMetadata(function).isCompanionFunction); + } +} + } // namespace facebook::velox::aggregate::test diff --git a/velox/functions/tests/FunctionRegistryTest.cpp b/velox/functions/tests/FunctionRegistryTest.cpp index 0456c77c6fd7..bbb72427d14d 100644 --- a/velox/functions/tests/FunctionRegistryTest.cpp +++ b/velox/functions/tests/FunctionRegistryTest.cpp @@ -24,6 +24,7 @@ #include "velox/functions/FunctionRegistry.h" #include "velox/functions/Macros.h" #include "velox/functions/Registerer.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" #include "velox/functions/tests/RegistryTestUtil.h" @@ -356,6 +357,23 @@ TEST_F(FunctionRegistryTest, isDeterministic) { ASSERT_FALSE(isDeterministic("not_found_function").has_value()); } +TEST_F(FunctionRegistryTest, isCompanionFunction) { + functions::prestosql::registerAllScalarFunctions(); + // extract aggregate companion functions are registered as vector functions. + aggregate::prestosql::registerAllAggregateFunctions(); + const auto functions = {"array_frequency", "bitwise_left_shift", "ceil"}; + const auto companionFunctions = { + "array_agg_extract", "arbitrary_extract", "bitwise_and_agg_extract"}; + + for (const auto& function : functions) { + ASSERT_FALSE( + exec::getVectorFunctionMetadata(function)->isCompanionFunction); + } + for (const auto& function : companionFunctions) { + ASSERT_TRUE(exec::getVectorFunctionMetadata(function)->isCompanionFunction); + } +} + template struct TestFunction { VELOX_DEFINE_FUNCTION_TYPES(T);