Skip to content

Commit

Permalink
Add FunctionRegistry APIs to retrieve function metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
pramodsatya committed Aug 2, 2024
1 parent 97ad9ad commit 4f11b42
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 82 deletions.
80 changes: 0 additions & 80 deletions velox/functions/CoverageUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,86 +270,6 @@ void printCoverageMap(
std::cout << out.str() << std::endl;
}

// A function name is a companion function's if the name is an existing
// aggregation functio name followed by a specific suffixes.
bool isCompanionFunctionName(
const std::string& name,
const std::unordered_map<std::string, exec::AggregateFunctionEntry>&
aggregateFunctions) {
auto suffixOffset = name.rfind("_partial");
if (suffixOffset == std::string::npos) {
suffixOffset = name.rfind("_merge_extract");
}
if (suffixOffset == std::string::npos) {
suffixOffset = name.rfind("_merge");
}
if (suffixOffset == std::string::npos) {
suffixOffset = name.rfind("_extract");
}
if (suffixOffset == std::string::npos) {
return false;
}
return aggregateFunctions.count(name.substr(0, suffixOffset)) > 0;
}

/// Returns alphabetically sorted list of scalar functions available in Velox,
/// excluding companion functions.
std::vector<std::string> getSortedScalarNames() {
// Do not print "internal" functions.
static const std::unordered_set<std::string> kBlockList = {"row_constructor"};

auto functions = getFunctionSignatures();

std::vector<std::string> names;
names.reserve(functions.size());
exec::aggregateFunctions().withRLock([&](const auto& aggregateFunctions) {
for (const auto& func : functions) {
const auto& name = func.first;
if (!isCompanionFunctionName(name, aggregateFunctions) &&
kBlockList.count(name) == 0) {
names.emplace_back(name);
}
}
});
std::sort(names.begin(), names.end());
return names;
}

/// Returns alphabetically sorted list of aggregate functions available in
/// Velox, excluding compaion functions.
std::vector<std::string> getSortedAggregateNames() {
std::vector<std::string> names;
exec::aggregateFunctions().withRLock([&](const auto& functions) {
names.reserve(functions.size());
for (const auto& entry : functions) {
if (!isCompanionFunctionName(entry.first, functions)) {
names.push_back(entry.first);
}
}
});
std::sort(names.begin(), names.end());
return names;
}

/// Returns alphabetically sorted list of window functions available in Velox,
/// excluding companion functions.
std::vector<std::string> getSortedWindowNames() {
const auto& functions = exec::windowFunctions();

std::vector<std::string> names;
names.reserve(functions.size());
exec::aggregateFunctions().withRLock([&](const auto& aggregateFunctions) {
for (const auto& entry : functions) {
if (!isCompanionFunctionName(entry.first, aggregateFunctions) &&
aggregateFunctions.count(entry.first) == 0) {
names.emplace_back(entry.first);
}
}
});
std::sort(names.begin(), names.end());
return names;
}

/// Takes a super-set of simple, vector and aggregate function names and prints
/// coverage map showing which of these functions are available in Velox.
/// Companion functions are excluded.
Expand Down
89 changes: 89 additions & 0 deletions velox/functions/FunctionRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,93 @@ resolveVectorFunctionWithMetadata(
return exec::resolveVectorFunctionWithMetadata(functionName, argTypes);
}

const bool isCompanionFunctionName(
const std::string& name,
const std::unordered_map<std::string, exec::AggregateFunctionEntry>&
aggregateFunctions) {
// The order of suffixes, '_merge_extract, _merge, _extract', in this list
// should be preserved.
static const std::vector<std::string> kCompanionFunctionSuffixList = {
"_partial", "_merge_extract", "_merge", "_extract"};
for (const auto& companionFunctionSuffix : kCompanionFunctionSuffixList) {
auto suffixOffset = name.rfind(companionFunctionSuffix);
if (suffixOffset != std::string::npos) {
return aggregateFunctions.count(name.substr(0, suffixOffset)) > 0;
}
}
return false;
}

const std::vector<std::string> getSortedScalarNames() {
// Do not print "internal" functions.
static const std::unordered_set<std::string> kBlockList = {
"row_constructor", "in", "is_null"};

auto functions = getFunctionSignatures();

std::vector<std::string> names;
names.reserve(functions.size());
exec::aggregateFunctions().withRLock([&](const auto& aggregateFunctions) {
for (const auto& func : functions) {
const auto& name = func.first;
if (!isCompanionFunctionName(name, aggregateFunctions) &&
kBlockList.count(name) == 0) {
names.emplace_back(name);
}
}
});
std::sort(names.begin(), names.end());
return names;
}

const std::vector<std::string> getSortedAggregateNames() {
std::vector<std::string> names;
exec::aggregateFunctions().withRLock([&](const auto& functions) {
names.reserve(functions.size());
for (const auto& entry : functions) {
if (!isCompanionFunctionName(entry.first, functions)) {
names.push_back(entry.first);
}
}
});
std::sort(names.begin(), names.end());
return names;
}

const std::vector<std::string> getSortedWindowNames() {
const auto& functions = exec::windowFunctions();

std::vector<std::string> names;
names.reserve(functions.size());
exec::aggregateFunctions().withRLock([&](const auto& aggregateFunctions) {
for (const auto& entry : functions) {
if (!isCompanionFunctionName(entry.first, aggregateFunctions) &&
aggregateFunctions.count(entry.first) == 0) {
names.emplace_back(entry.first);
}
}
});
std::sort(names.begin(), names.end());
return names;
}

std::optional<exec::VectorFunctionMetadata> getFunctionMetadata(
const std::string& functionName) {
auto simpleFunctionMetadata =
exec::simpleFunctions().getFunctionSignaturesAndMetadata(functionName);
if (simpleFunctionMetadata.size()) {
// Functions like abs are registered as simple functions for primitive
// types, and as a vector function for complex types like DECIMAL. So do not
// throw an error if function metadata is not found in simple function
// signature map.
return simpleFunctionMetadata.back().first;
}

auto vectorFunctionMetadata = exec::getVectorFunctionMetadata(functionName);
if (vectorFunctionMetadata.has_value()) {
return vectorFunctionMetadata.value();
}
return std::nullopt;
}

} // namespace facebook::velox
22 changes: 22 additions & 0 deletions velox/functions/FunctionRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <string>
#include <vector>

#include "velox/exec/Aggregate.h"
#include "velox/exec/WindowFunction.h"
#include "velox/expression/FunctionMetadata.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/type/Type.h"
Expand Down Expand Up @@ -95,4 +97,24 @@ resolveVectorFunctionWithMetadata(
/// Clears the function registry.
void clearFunctionRegistry();

/// A function name is a companion function's if the name is an existing
/// aggregation function name followed by one of specific suffixes.
const bool isCompanionFunctionName(
const std::string& name,
const std::unordered_map<std::string, exec::AggregateFunctionEntry>&
aggregateFunctions);

/// Returns sorted list of scalar function names available in Velox.
const std::vector<std::string> getSortedScalarNames();

/// Returns sorted list of aggregate function names available in Velox.
const std::vector<std::string> getSortedAggregateNames();

/// Returns sorted list of window function names available in Velox.
const std::vector<std::string> getSortedWindowNames();

/// Get the function metadata corresponding to functionName.
std::optional<exec::VectorFunctionMetadata> getFunctionMetadata(
const std::string& functionName);

} // namespace facebook::velox
26 changes: 24 additions & 2 deletions velox/functions/tests/FunctionRegistryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,10 @@ VELOX_DECLARE_VECTOR_FUNCTION(
VELOX_DECLARE_VECTOR_FUNCTION_WITH_METADATA(
udf_vector_func_four,
VectorFuncFour::signatures(),
exec::VectorFunctionMetadataBuilder().deterministic(false).build(),
exec::VectorFunctionMetadataBuilder()
.deterministic(false)
.defaultNullBehavior(false)
.build(),
std::make_unique<VectorFuncFour>());

inline void registerTestFunctions() {
Expand Down Expand Up @@ -627,7 +630,7 @@ TEST_F(FunctionRegistryTest, resolveWithMetadata) {
"vector_func_four", {MAP(INTEGER(), VARCHAR())});
EXPECT_TRUE(result.has_value());
EXPECT_EQ(*result->first, *ARRAY(INTEGER()));
EXPECT_TRUE(result->second.defaultNullBehavior);
EXPECT_FALSE(result->second.defaultNullBehavior);
EXPECT_FALSE(result->second.deterministic);
EXPECT_FALSE(result->second.supportsFlattening);

Expand Down Expand Up @@ -655,4 +658,23 @@ TEST_F(FunctionRegistryOverwriteTest, overwrite) {
ASSERT_EQ(signatures.size(), 1);
}

TEST_F(FunctionRegistryTest, functionMetadata) {
auto checkMetadata =
[&](const StringView& name, bool determinism, bool defaultNullBehavior) {
auto metadata = getFunctionMetadata(name);
EXPECT_TRUE(metadata.has_value());
ASSERT_EQ(metadata.value().deterministic, determinism);
ASSERT_EQ(metadata.value().defaultNullBehavior, defaultNullBehavior);
};

// Validate VectorFunctionMetadata for simple functions func_one and func_two.
checkMetadata("func_one", false, true);
checkMetadata("func_two", true, false);

// Validate VectorFunctionMetadata for vector functions vector_func_three and
// vector_func_four.
checkMetadata("vector_func_three", true, true);
checkMetadata("vector_func_four", false, false);
}

} // namespace facebook::velox

0 comments on commit 4f11b42

Please sign in to comment.