Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add isCompanionFunction to function metadata #9250

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions velox/exec/Aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,25 @@ AggregateRegistrationResult registerAggregateFunction(
}

// Register the aggregate as a window function also.
registerAggregateWindowFunction(sanitizedName);
registerAggregateWindowFunction(sanitizedName, metadata);

// Register companion function if needed.
if (registerCompanionFunctions) {
auto companionMetadata = metadata;
companionMetadata.isCompanionFunction = true;

registered.partialFunction =
CompanionFunctionsRegistrar::registerPartialFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
registered.mergeFunction =
CompanionFunctionsRegistrar::registerMergeFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
registered.extractFunction =
CompanionFunctionsRegistrar::registerExtractFunction(
name, signatures, overwrite);
registered.mergeExtractFunction =
CompanionFunctionsRegistrar::registerMergeExtractFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
}
return registered;
}
Expand Down Expand Up @@ -141,6 +144,16 @@ std::vector<AggregateRegistrationResult> registerAggregateFunction(
return registrationResults;
}

/// Looks up the aggregate function registered under 'name' (after the name is
/// sanitized) and returns a reference to the metadata stored in its registry
/// entry. Raises a user error if no entry is found.
const AggregateFunctionMetadata& getAggregateFunctionMetadata(
    const std::string& name) {
  const auto entry = getAggregateFunctionEntry(sanitizeName(name));
  if (!entry) {
    // The error message reports the name as passed by the caller, not the
    // sanitized form used for the lookup.
    VELOX_USER_FAIL("Metadata not found for aggregate function: {}", name);
  }
  return entry->metadata;
}

std::unordered_map<
std::string,
std::vector<std::shared_ptr<AggregateFunctionSignature>>>
Expand Down
6 changes: 6 additions & 0 deletions velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,9 @@ struct AggregateFunctionMetadata {
/// True if results of the aggregation depend on the order of inputs. For
/// example, array_agg is order sensitive while count is not.
bool orderSensitive{true};

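/// True if this function was registered as a companion function (e.g. the
/// partial, merge or merge-extract companion) of another aggregate function.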
bool isCompanionFunction{false};
};
/// Register an aggregate function with the specified name and signatures. If
/// registerCompanionFunctions is true, also register companion aggregate and
Expand Down Expand Up @@ -514,6 +517,9 @@ std::vector<AggregateRegistrationResult> registerAggregateFunction(
bool registerCompanionFunctions,
bool overwrite);

/// Returns the metadata of the aggregate function with the specified name. The
/// name is sanitized before the lookup. Fails with a user error if no
/// aggregate function with that name is found.
const AggregateFunctionMetadata& getAggregateFunctionMetadata(
const std::string& name);

/// Returns signatures of the aggregate function with the specified name.
/// Returns empty std::optional if function with that name is not found.
std::optional<std::vector<std::shared_ptr<AggregateFunctionSignature>>>
Expand Down
25 changes: 20 additions & 5 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ void AggregateCompanionAdapter::ExtractFunction::apply(
bool CompanionFunctionsRegistrar::registerPartialFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto partialSignatures =
CompanionSignatures::partialFunctionSignatures(signatures);
Expand Down Expand Up @@ -280,6 +281,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
name,
CompanionSignatures::partialFunctionName(name));
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
Expand All @@ -288,6 +290,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
bool CompanionFunctionsRegistrar::registerMergeFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto mergeSignatures =
CompanionSignatures::mergeFunctionSignatures(signatures);
Expand Down Expand Up @@ -320,16 +323,18 @@ bool CompanionFunctionsRegistrar::registerMergeFunction(
name,
CompanionSignatures::mergeFunctionName(name));
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
}

bool registerAggregateFunction(
bool registerMergeExtractFunctionImpl(
const std::string& name,
const std::string& mergeExtractFunctionName,
const std::vector<std::shared_ptr<AggregateFunctionSignature>>&
mergeExtractSignatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
return exec::registerAggregateFunction(
mergeExtractFunctionName,
Expand Down Expand Up @@ -365,6 +370,7 @@ bool registerAggregateFunction(
name,
mergeExtractFunctionName);
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
Expand All @@ -373,6 +379,7 @@ bool registerAggregateFunction(
bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
Expand All @@ -387,10 +394,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type);

registered |= registerAggregateFunction(
registered |= registerMergeExtractFunctionImpl(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
metadata,
overwrite);
}
return registered;
Expand All @@ -399,10 +407,12 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
return registerMergeExtractFunctionWithSuffix(
name, signatures, metadata, overwrite);
}

auto mergeExtractSignatures =
Expand All @@ -413,10 +423,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return registerAggregateFunction(
return registerMergeExtractFunctionImpl(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename to registerMergeExtractFunctionInternal.

name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
metadata,
overwrite);
}

Expand Down Expand Up @@ -475,6 +486,7 @@ bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
std::move(factory),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.isCompanionFunction(true)
.build(),
overwrite);
}
Expand Down Expand Up @@ -502,7 +514,10 @@ bool CompanionFunctionsRegistrar::registerExtractFunction(
CompanionSignatures::extractFunctionName(originalName),
std::move(extractSignatures),
std::move(factory),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.isCompanionFunction(true)
.build(),
overwrite);
}

Expand Down
4 changes: 4 additions & 0 deletions velox/exec/AggregateCompanionAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class CompanionFunctionsRegistrar {
static bool registerPartialFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

// When there is already a function of the same name as the merge companion
Expand All @@ -186,6 +187,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

// If there are multiple signatures of the original aggregation function
Expand Down Expand Up @@ -213,6 +215,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

private:
Expand All @@ -227,6 +230,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite);
};

Expand Down
8 changes: 6 additions & 2 deletions velox/exec/AggregateWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,9 @@ class AggregateWindowFunction : public exec::WindowFunction {

} // namespace

void registerAggregateWindowFunction(const std::string& name) {
void registerAggregateWindowFunction(
const std::string& name,
const AggregateFunctionMetadata& metadata) {
auto aggregateFunctionSignatures = exec::getAggregateFunctionSignatures(name);
if (aggregateFunctionSignatures.has_value()) {
// This copy is needed to obtain a vector of the base FunctionSignaturePtr
Expand All @@ -410,7 +412,9 @@ void registerAggregateWindowFunction(const std::string& name) {
exec::registerWindowFunction(
name,
std::move(signatures),
{exec::WindowFunction::ProcessMode::kRows, true},
{exec::WindowFunction::ProcessMode::kRows,
true,
metadata.isCompanionFunction},
[name](
const std::vector<exec::WindowFunctionArg>& args,
const TypePtr& resultType,
Expand Down
5 changes: 4 additions & 1 deletion velox/exec/AggregateWindow.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
*/
#pragma once
#include <string>
#include "velox/exec/Aggregate.h"

namespace facebook::velox::exec {

void registerAggregateWindowFunction(const std::string& name);
void registerAggregateWindowFunction(
const std::string& name,
const AggregateFunctionMetadata& metadata);

} // namespace facebook::velox::exec
7 changes: 4 additions & 3 deletions velox/exec/WindowFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ class WindowFunction {
kRows,
};

/// Indicates whether this is an aggregate window function and its process
/// unit.
/// Indicates whether this is an aggregate window function, whether it is a
/// companion function, and its process unit.
struct Metadata {
ProcessMode processMode;
bool isAggregate;
bool isCompanionFunction;

static Metadata defaultMetadata() {
static Metadata defaultValue{ProcessMode::kPartition, false};
static Metadata defaultValue{ProcessMode::kPartition, false, false};
return defaultValue;
}
};
Expand Down
19 changes: 19 additions & 0 deletions velox/exec/tests/WindowFunctionRegistryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "velox/exec/WindowFunction.h"
#include "velox/expression/SignatureBinder.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"

namespace facebook::velox::exec::test {
Expand Down Expand Up @@ -134,4 +135,22 @@ TEST_F(WindowFunctionRegistryTest, prefix) {
}
}

// Verifies the isCompanionFunction flag reported by the window function
// registry: regular window and aggregate functions must report false, while
// companion functions of aggregates must report true.
TEST_F(WindowFunctionRegistryTest, isCompanionFunction) {
  aggregate::prestosql::registerAllAggregateFunctions();
  window::prestosql::registerAllWindowFunctions();

  const char* const windowFunctions[] = {
      "count", "lead", "ntile", "nth_value", "first_value", "map_union_sum"};
  const char* const companionFunctions[] = {
      "approx_most_frequent_partial",
      "approx_percentile_merge",
      "arbitrary_merge_extract"};

  // Stream the function name into each assertion so that a failure identifies
  // which function carried the unexpected flag.
  for (const auto* function : windowFunctions) {
    ASSERT_FALSE(getWindowFunctionMetadata(function).isCompanionFunction)
        << "Function: " << function;
  }
  for (const auto* function : companionFunctions) {
    ASSERT_TRUE(getWindowFunctionMetadata(function).isCompanionFunction)
        << "Function: " << function;
  }
}

} // namespace facebook::velox::exec::test
8 changes: 8 additions & 0 deletions velox/expression/FunctionMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ struct VectorFunctionMetadata {
/// In this case, 'rows' in VectorFunction::apply will point only to positions
/// for which all arguments are not null.
bool defaultNullBehavior{true};

/// True if this function was registered as a companion function of an
/// aggregate function (e.g. an extract companion function).
bool isCompanionFunction{false};
};

class VectorFunctionMetadataBuilder {
Expand All @@ -59,6 +62,11 @@ class VectorFunctionMetadataBuilder {
return *this;
}

/// Sets the isCompanionFunction flag of the metadata being built and returns
/// this builder so that calls can be chained.
VectorFunctionMetadataBuilder& isCompanionFunction(bool isCompanionFunction) {
metadata_.isCompanionFunction = isCompanionFunction;
return *this;
}

const VectorFunctionMetadata& build() const {
return metadata_;
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/lib/aggregates/BitwiseAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ exec::AggregateRegistrationResult registerBitwise(
inputType->kindName());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/lib/window/Rank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ void registerRankInternal(
exec::registerWindowFunction(
name,
std::move(signatures),
{exec::WindowFunction::ProcessMode::kRows, false},
{exec::WindowFunction::ProcessMode::kRows, false, false},
std::move(windowFunctionFactory));
} else {
exec::registerWindowFunction(
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/lib/window/RowNumber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void registerRowNumber(const std::string& name, TypeKind resultTypeKind) {
exec::registerWindowFunction(
name,
std::move(signatures),
{exec::WindowFunction::ProcessMode::kRows, false},
{exec::WindowFunction::ProcessMode::kRows, false, false},
[name](
const std::vector<exec::WindowFunctionArg>& /*args*/,
const TypePtr& resultType,
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void registerAverageAggregate(
}
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/BoolAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ exec::AggregateRegistrationResult registerBool(
inputType->kindName());
return std::make_unique<T>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/ChecksumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ void registerChecksumAggregate(

return std::make_unique<ChecksumAggregate>(VARBINARY());
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/CountAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ void registerCountAggregate(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<CountAggregate>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/CountIfAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void registerCountIfAggregate(

return std::make_unique<CountIfAggregate>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
Loading
Loading