Skip to content

Commit

Permalink
Expose helper functions to retrieve registered functions from Functio…
Browse files Browse the repository at this point in the history
…nRegistry
  • Loading branch information
pramodsatya committed Aug 8, 2024
1 parent f7e1f35 commit a3778d6
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 83 deletions.
1 change: 1 addition & 0 deletions velox/functions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ velox_add_library(velox_coverage_util CoverageUtil.cpp)

velox_link_libraries(
velox_function_registry
velox_exec
velox_expression
velox_type
velox_core
Expand Down
86 changes: 3 additions & 83 deletions velox/functions/CoverageUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,86 +270,6 @@ void printCoverageMap(
std::cout << out.str() << std::endl;
}

// A function name is a companion function's name if it is the name of an
// existing aggregation function followed by one of the companion suffixes
// ("_partial", "_merge_extract", "_merge", "_extract").
//
// 'name' is the candidate function name; 'aggregateFunctions' is the map of
// aggregate functions keyed by name. Returns true if, for some suffix, the
// part of 'name' preceding the last occurrence of that suffix is a key of
// 'aggregateFunctions'.
//
// NOTE(review): the suffix is located with rfind, so it is not required to
// terminate 'name' -- presumably to tolerate companion names that carry an
// extra trailing component; confirm before tightening this to an
// "ends with" check.
bool isCompanionFunctionName(
    const std::string& name,
    const std::unordered_map<std::string, exec::AggregateFunctionEntry>&
        aggregateFunctions) {
  // Order matters: "_merge_extract" must be tried before "_merge" and
  // "_extract" so that the longest suffix wins.
  static const std::string kSuffixes[] = {
      "_partial", "_merge_extract", "_merge", "_extract"};
  for (const auto& suffix : kSuffixes) {
    const auto suffixOffset = name.rfind(suffix);
    if (suffixOffset == std::string::npos) {
      continue;
    }
    // Do not give up when this suffix's prefix is not an aggregate: the
    // aggregate's own name may contain an earlier-listed suffix (e.g. a
    // hypothetical aggregate "a_partial_b" with companion "a_partial_b_merge"),
    // in which case only a later suffix yields the right prefix.
    if (aggregateFunctions.count(name.substr(0, suffixOffset)) > 0) {
      return true;
    }
  }
  return false;
}

/// Collects the names reported by getFunctionSignatures(), drops companion
/// function names and block-listed "internal" names, and returns the remainder
/// in ascending (std::sort) order.
std::vector<std::string> getSortedScalarNames() {
  // "Internal" functions that must never be listed.
  static const std::unordered_set<std::string> kBlockList = {"row_constructor"};

  auto signatureMap = getFunctionSignatures();

  std::vector<std::string> result;
  result.reserve(signatureMap.size());
  // The aggregate registry is read under its read lock because companion
  // detection needs to look up aggregate names.
  exec::aggregateFunctions().withRLock([&](const auto& aggregates) {
    for (const auto& entry : signatureMap) {
      const auto& candidate = entry.first;
      if (isCompanionFunctionName(candidate, aggregates)) {
        continue;
      }
      if (kBlockList.count(candidate) != 0) {
        continue;
      }
      result.emplace_back(candidate);
    }
  });
  std::sort(result.begin(), result.end());
  return result;
}

/// Returns the names of the registered aggregate functions in ascending
/// (std::sort) order, leaving out companion function names.
std::vector<std::string> getSortedAggregateNames() {
  std::vector<std::string> result;
  exec::aggregateFunctions().withRLock([&result](const auto& registry) {
    result.reserve(registry.size());
    for (auto it = registry.begin(); it != registry.end(); ++it) {
      const auto& candidate = it->first;
      // Companion names live in the same registry as their base aggregate.
      if (isCompanionFunctionName(candidate, registry)) {
        continue;
      }
      result.push_back(candidate);
    }
  });
  std::sort(result.begin(), result.end());
  return result;
}

/// Returns the names of the registered window functions in ascending
/// (std::sort) order. Names that are companion function names or that are
/// also registered as aggregate functions are left out.
std::vector<std::string> getSortedWindowNames() {
  const auto& windowRegistry = exec::windowFunctions();

  std::vector<std::string> result;
  result.reserve(windowRegistry.size());
  exec::aggregateFunctions().withRLock([&](const auto& aggregates) {
    for (const auto& entry : windowRegistry) {
      const auto& candidate = entry.first;
      if (isCompanionFunctionName(candidate, aggregates)) {
        continue;
      }
      if (aggregates.count(candidate) != 0) {
        continue;
      }
      result.emplace_back(candidate);
    }
  });
  std::sort(result.begin(), result.end());
  return result;
}

/// Takes a super-set of simple, vector and aggregate function names and prints
/// coverage map showing which of these functions are available in Velox.
/// Companion functions are excluded.
Expand Down Expand Up @@ -411,9 +331,9 @@ void printCoverageMapForAll(const std::string& domain) {
void printVeloxFunctions(
const std::unordered_set<std::string>& linkBlockList,
const std::string& domain) {
auto scalarNames = getSortedScalarNames();
auto aggNames = getSortedAggregateNames();
auto windowNames = getSortedWindowNames();
auto scalarNames = getScalarNames(true);
auto aggNames = getAggregateNames(true);
auto windowNames = getWindowNames(true);

const int columnSize = std::max(
{maxLength(scalarNames),
Expand Down
74 changes: 74 additions & 0 deletions velox/functions/FunctionRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <sstream>
#include "velox/common/base/Exceptions.h"
#include "velox/core/SimpleFunctionMetadata.h"
#include "velox/exec/WindowFunction.h"
#include "velox/expression/FunctionCallToSpecialForm.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/expression/SignatureBinder.h"
Expand Down Expand Up @@ -160,4 +161,77 @@ resolveVectorFunctionWithMetadata(
return exec::resolveVectorFunctionWithMetadata(functionName, argTypes);
}

/// Returns true if 'name' is a companion function name, i.e. the name of an
/// existing aggregate function followed by one of the companion suffixes
/// ("_partial", "_merge_extract", "_merge", "_extract").
///
/// @param name Candidate function name.
/// @param aggregateFunctions Aggregate functions keyed by function name.
/// @return true if, for some suffix, the part of 'name' preceding the last
/// occurrence of that suffix is a key of 'aggregateFunctions'.
///
/// NOTE(review): the suffix is located with rfind, so it is not required to
/// terminate 'name' -- presumably to tolerate companion names that carry an
/// extra trailing component; confirm before tightening this to an
/// "ends with" check.
bool isCompanionFunctionName(
    const std::string& name,
    const std::unordered_map<std::string, exec::AggregateFunctionEntry>&
        aggregateFunctions) {
  // Order matters: "_merge_extract" must be tried before "_merge" and
  // "_extract" so that the longest suffix wins.
  static const std::vector<std::string> kCompanionFunctionSuffixList = {
      "_partial", "_merge_extract", "_merge", "_extract"};
  for (const auto& companionFunctionSuffix : kCompanionFunctionSuffixList) {
    const auto suffixOffset = name.rfind(companionFunctionSuffix);
    // Keep trying the remaining suffixes when this one is absent or its prefix
    // is not an aggregate: the aggregate's own name may contain an
    // earlier-listed suffix (e.g. a hypothetical aggregate "a_partial_b" with
    // companion "a_partial_b_merge").
    if (suffixOffset != std::string::npos &&
        aggregateFunctions.count(name.substr(0, suffixOffset)) > 0) {
      return true;
    }
  }
  return false;
}

/// Returns the names reported by getFunctionSignatures(), excluding companion
/// function names and the block-listed "internal" names. The result is put in
/// ascending (std::sort) order only when 'sortResult' is true.
const std::vector<std::string> getScalarNames(bool sortResult) {
  // "Internal" functions that are never reported.
  static const std::unordered_set<std::string> kBlockList = {
      "row_constructor", "in", "is_null"};

  auto signatureMap = getFunctionSignatures();
  std::vector<std::string> result;
  result.reserve(signatureMap.size());
  // Companion detection looks up aggregate names, so hold the aggregate
  // registry's read lock for the whole scan.
  exec::aggregateFunctions().withRLock([&](const auto& aggregates) {
    for (const auto& entry : signatureMap) {
      const auto& candidate = entry.first;
      if (isCompanionFunctionName(candidate, aggregates)) {
        continue;
      }
      if (kBlockList.count(candidate) != 0) {
        continue;
      }
      result.emplace_back(candidate);
    }
  });
  if (!sortResult) {
    return result;
  }
  std::sort(result.begin(), result.end());
  return result;
}

/// Returns the names of the registered aggregate functions, excluding
/// companion function names. The result is put in ascending (std::sort) order
/// only when 'sortResult' is true.
const std::vector<std::string> getAggregateNames(bool sortResult) {
  std::vector<std::string> result;
  exec::aggregateFunctions().withRLock([&result](const auto& registry) {
    result.reserve(registry.size());
    for (auto it = registry.begin(); it != registry.end(); ++it) {
      const auto& candidate = it->first;
      // Companion names live in the same registry as their base aggregate.
      if (isCompanionFunctionName(candidate, registry)) {
        continue;
      }
      result.emplace_back(candidate);
    }
  });
  if (!sortResult) {
    return result;
  }
  std::sort(result.begin(), result.end());
  return result;
}

/// Returns the names of the registered window functions, excluding every name
/// that is also registered as an aggregate function. The result is put in
/// ascending (std::sort) order only when 'sortResult' is true.
const std::vector<std::string> getWindowNames(bool sortResult) {
  const auto& windowRegistry = exec::windowFunctions();
  std::vector<std::string> result;
  result.reserve(windowRegistry.size());
  // Drop every window function whose name is also an aggregate function name.
  exec::aggregateFunctions().withRLock([&](const auto& aggregates) {
    for (const auto& entry : windowRegistry) {
      const auto& candidate = entry.first;
      if (aggregates.count(candidate) != 0) {
        continue;
      }
      result.emplace_back(candidate);
    }
  });

  if (!sortResult) {
    return result;
  }
  std::sort(result.begin(), result.end());
  return result;
}

} // namespace facebook::velox
17 changes: 17 additions & 0 deletions velox/functions/FunctionRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <string>
#include <vector>

#include "velox/exec/Aggregate.h"
#include "velox/expression/FunctionMetadata.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/type/Type.h"
Expand Down Expand Up @@ -95,4 +96,20 @@ resolveVectorFunctionWithMetadata(
/// Clears the function registry.
void clearFunctionRegistry();

/// Returns true if 'name' is a companion function name, i.e. the name of an
/// existing aggregation function followed by one of the companion suffixes
/// ("_partial", "_merge_extract", "_merge", "_extract").
///
/// @param name Function name to check.
/// @param aggregateFunctions Aggregate functions keyed by function name; used
/// to test whether the part of 'name' preceding the suffix is an existing
/// aggregate function.
bool isCompanionFunctionName(
const std::string& name,
const std::unordered_map<std::string, exec::AggregateFunctionEntry>&
aggregateFunctions);

/// The following helper functions return the names of the scalar, aggregate,
/// and window functions registered in Velox. The returned list is sorted in
/// ascending order only when 'sortResult' is true; by default it is unsorted.

/// Scalar function names, excluding companion functions and internal
/// functions.
const std::vector<std::string> getScalarNames(bool sortResult = false);

/// Aggregate function names, excluding companion functions.
const std::vector<std::string> getAggregateNames(bool sortResult = false);

/// Window function names, excluding names that are also registered as
/// aggregate functions.
const std::vector<std::string> getWindowNames(bool sortResult = false);

} // namespace facebook::velox
1 change: 1 addition & 0 deletions velox/functions/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ target_link_libraries(
velox_function_registry_test
velox_function_registry
velox_functions_test_lib
velox_window
gmock
gtest
gtest_main)
51 changes: 51 additions & 0 deletions velox/functions/tests/FunctionRegistryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
#include "velox/functions/FunctionRegistry.h"
#include "velox/functions/Macros.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
#include "velox/type/Type.h"

namespace facebook::velox {
Expand Down Expand Up @@ -635,6 +637,55 @@ TEST_F(FunctionRegistryTest, resolveWithMetadata) {
EXPECT_FALSE(result.has_value());
}

// Verifies that getScalarNames() reports a sample of scalar functions once
// functions::prestosql::registerAllScalarFunctions() has been called.
TEST_F(FunctionRegistryTest, getScalarFunctions) {
functions::prestosql::registerAllScalarFunctions();
// Called with the default argument, so the result is not sorted; membership
// is therefore checked with a linear std::find.
const auto scalars = getScalarNames();
const auto expectedScalars = {
"date_parse", "array_sort", "map_filter", "split_to_map"};
for (const auto& name : expectedScalars) {
ASSERT_TRUE(
std::find(scalars.begin(), scalars.end(), name) != scalars.end());
}
}

// Verifies that getAggregateNames() reports a sample of aggregate functions
// after aggregate::prestosql::registerAllAggregateFunctions() and that it
// leaves out companion function names.
TEST_F(FunctionRegistryTest, getAggregateFunctions) {
aggregate::prestosql::registerAllAggregateFunctions();
// Default argument: the result is not sorted, hence the linear std::find.
const auto aggregates = getAggregateNames();
const auto expectedAggregates = {
"approx_distinct", "covar_pop", "count", "map_union_sum"};
// Names built from an aggregate name plus a companion suffix.
const auto aggregateCompanions = {
"approx_distinct_merge_extract", "approx_percentile_merge"};
for (const auto& name : expectedAggregates) {
ASSERT_TRUE(
std::find(aggregates.begin(), aggregates.end(), name) !=
aggregates.end());
}
// Verify getAggregateNames does not return companion functions.
for (const auto& companion : aggregateCompanions) {
ASSERT_TRUE(
std::find(aggregates.begin(), aggregates.end(), companion) ==
aggregates.end());
}
}

// Verifies that getWindowNames() reports a sample of window functions after
// window::prestosql::registerAllWindowFunctions() and that it does not report
// aggregate function names.
// NOTE(review): this test does not register aggregate functions itself, so the
// negative check below may only be meaningful when aggregates were registered
// elsewhere in the same process -- confirm.
TEST_F(FunctionRegistryTest, getWindowFunctions) {
window::prestosql::registerAllWindowFunctions();
// Default argument: the result is not sorted, hence the linear std::find.
auto windows = getWindowNames();
auto expectedWindow = {"lead", "ntile", "nth_value", "first_value"};
for (const auto& name : expectedWindow) {
ASSERT_TRUE(
std::find(windows.begin(), windows.end(), name) != windows.end());
}

// Verify getWindowNames does not return aggregate functions.
const auto aggregates = {
"approx_distinct", "covar_pop", "count", "map_union_sum"};
for (const auto& name : aggregates) {
ASSERT_TRUE(
std::find(windows.begin(), windows.end(), name) == windows.end());
}
}

class FunctionRegistryOverwriteTest : public functions::test::FunctionBaseTest {
public:
FunctionRegistryOverwriteTest() {
Expand Down

0 comments on commit a3778d6

Please sign in to comment.