Skip to content

Commit

Permalink
[native] Retrieve Json function metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
pramodsatya committed Apr 24, 2024
1 parent 369f9d3 commit 6386575
Show file tree
Hide file tree
Showing 11 changed files with 12,618 additions and 11,945 deletions.
14 changes: 14 additions & 0 deletions presto-native-execution/presto_cpp/main/PrestoServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "presto_cpp/main/operators/PartitionAndSerialize.h"
#include "presto_cpp/main/operators/ShuffleRead.h"
#include "presto_cpp/main/operators/UnsafeRowExchangeSource.h"
#include "presto_cpp/main/types/JsonFunctionMetadata.h"
#include "presto_cpp/main/types/PrestoToVeloxConnector.h"
#include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h"
#include "velox/common/base/Counters.h"
Expand Down Expand Up @@ -347,6 +348,14 @@ void PrestoServer::run() {
http::kMimeTypeApplicationJson)
.sendWithEOM();
});
httpServer_->registerGet(
"/v1/info/workerFunctionSignatures",
[server = this](
proxygen::HTTPMessage* /*message*/,
const std::vector<std::unique_ptr<folly::IOBuf>>& /*body*/,
proxygen::ResponseHandler* downstream) {
server->getFunctionSignatures(downstream);
});

registerFunctions();
registerRemoteFunctions();
Expand Down Expand Up @@ -1140,6 +1149,11 @@ void PrestoServer::reportNodeStatus(proxygen::ResponseHandler* downstream) {
http::sendOkResponse(downstream, json(fetchNodeStatus()));
}

/// Replies on 'downstream' with an OK response whose payload is the JSON
/// document produced by getJsonFunctionMetadata().
void PrestoServer::getFunctionSignatures(
    proxygen::ResponseHandler* downstream) {
  const auto functionMetadataJson = getJsonFunctionMetadata();
  http::sendOkResponse(downstream, functionMetadataJson);
}

protocol::NodeStatus PrestoServer::fetchNodeStatus() {
auto systemConfig = SystemConfig::instance();
const int64_t nodeMemoryGb = systemConfig->systemMemoryGb();
Expand Down
2 changes: 2 additions & 0 deletions presto-native-execution/presto_cpp/main/PrestoServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ class PrestoServer {

void reportNodeStatus(proxygen::ResponseHandler* downstream);

/// Sends an OK response on 'downstream' carrying the JSON returned by
/// getJsonFunctionMetadata().
void getFunctionSignatures(proxygen::ResponseHandler* downstream);

protocol::NodeStatus fetchNodeStatus();

void populateMemAndCPUInfo();
Expand Down
4 changes: 4 additions & 0 deletions presto-native-execution/presto_cpp/main/types/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ target_link_libraries(presto_types presto_type_converter velox_type_fbhive

set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool)

# Object library compiled from JsonFunctionMetadata.cpp.
add_library(presto_function_metadata OBJECT JsonFunctionMetadata.cpp)

# JsonFunctionMetadata.cpp is linked against the velox_function_registry target.
target_link_libraries(presto_function_metadata velox_function_registry)

if(PRESTO_ENABLE_TESTING)
add_subdirectory(tests)
endif()
221 changes: 221 additions & 0 deletions presto-native-execution/presto_cpp/main/types/JsonFunctionMetadata.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "presto_cpp/main/types/JsonFunctionMetadata.h"

using namespace facebook::velox;
using namespace facebook::velox::exec;

namespace facebook::presto {

namespace {

constexpr char const* kDefaultSchema = "default";

// The keys in velox function maps are of the format
// catalog.schema.function_name. This utility function extracts the
// function_name from this string, i.e. the last non-empty '.'-separated
// component ("a.b.c" -> "c", "a.b." -> "b", "abs" -> "abs").
//
// If the input has no such component (it is empty or consists only of '.'),
// the input is returned unchanged rather than reading the back of an empty
// container, which would be undefined behavior.
const std::string getFunctionName(const std::string& registeredFunctionName) {
  // Ignore trailing separators so that empty trailing pieces are skipped.
  const auto last = registeredFunctionName.find_last_not_of('.');
  if (last == std::string::npos) {
    return registeredFunctionName;
  }

  // The component starts right after the closest separator preceding 'last',
  // or at the beginning of the string when there is none.
  const auto separator = registeredFunctionName.rfind('.', last);
  const auto first = (separator == std::string::npos) ? 0 : separator + 1;
  return registeredFunctionName.substr(first, last - first + 1);
}

// Builds the protocol aggregation metadata for a single aggregate signature:
// the intermediate type is taken from the signature, order sensitivity is
// currently hard-coded.
const protocol::AggregationFunctionMetadata getAggregationFunctionMetadata(
    const AggregateFunctionSignature& aggregateFunctionSignature) {
  protocol::AggregationFunctionMetadata metadata;
  /// TODO: Set to true for now. To be read from an existing mapping of
  /// aggregate to order sensitivity which needs to be added.
  metadata.isOrderSensitive = true;
  metadata.intermediateType =
      aggregateFunctionSignature.intermediateType().toString();
  return metadata;
}

// Builds the routine characteristics (language, determinism, null-call
// clause) reported for a function of the given kind.
//
// For SCALAR functions determinism and null behavior are read from the
// metadata returned by getFunctionMetadata(functionName); every other kind
// gets DETERMINISTIC and CALLED_ON_NULL_INPUT. The language is always "CPP".
// 'functionSignature' is currently not consulted.
const protocol::RoutineCharacteristics getRoutineCharacteristics(
    const FunctionSignature& functionSignature,
    const std::string& functionName,
    const protocol::FunctionKind& functionKind) {
  // Defaults used for non-scalar functions; scalar functions override them.
  protocol::Determinism determinism = protocol::Determinism::DETERMINISTIC;
  protocol::NullCallClause nullCallClause =
      protocol::NullCallClause::CALLED_ON_NULL_INPUT;

  if (functionKind == protocol::FunctionKind::SCALAR) {
    // NOTE(review): the same-named getFunctionMetadata defined later in this
    // namespace is not declared yet at this point, so this call presumably
    // resolves to the one made visible by the using-directives — confirm.
    const auto scalarMetadata = getFunctionMetadata(functionName);
    if (!scalarMetadata.deterministic) {
      determinism = protocol::Determinism::NOT_DETERMINISTIC;
    }
    if (scalarMetadata.defaultNullBehavior) {
      nullCallClause = protocol::NullCallClause::RETURNS_NULL_ON_NULL_INPUT;
    }
  }

  protocol::RoutineCharacteristics characteristics;
  characteristics.language =
      std::make_shared<protocol::Language>(protocol::Language({"CPP"}));
  characteristics.determinism =
      std::make_shared<protocol::Determinism>(determinism);
  characteristics.nullCallClause =
      std::make_shared<protocol::NullCallClause>(nullCallClause);
  return characteristics;
}

// Fills the signature-derived fields of 'jsonBasedUdfFunctionMetadata':
//  - docString is set to 'functionName',
//  - schema is set to kDefaultSchema,
//  - outputType is the string form of the signature's return type,
//  - paramTypes holds the string form of each argument type, in order.
// Other fields of 'jsonBasedUdfFunctionMetadata' are left untouched.
void updateFunctionMetadata(
    const std::string& functionName,
    const FunctionSignature& functionSignature,
    protocol::JsonBasedUdfFunctionMetadata& jsonBasedUdfFunctionMetadata) {
  jsonBasedUdfFunctionMetadata.docString = functionName;
  jsonBasedUdfFunctionMetadata.schema = kDefaultSchema;
  jsonBasedUdfFunctionMetadata.outputType =
      functionSignature.returnType().toString();

  // Bind by reference: the argument list is only read, so copying the whole
  // vector of type signatures is unnecessary.
  const auto& argumentTypes = functionSignature.argumentTypes();
  std::vector<std::string> paramTypes;
  paramTypes.reserve(argumentTypes.size());
  for (const auto& argumentType : argumentTypes) {
    paramTypes.emplace_back(argumentType.toString());
  }
  // Hand the fully built list over without copying it a second time. The
  // destination is only modified once every conversion has succeeded.
  jsonBasedUdfFunctionMetadata.paramTypes.swap(paramTypes);
}

const std::vector<protocol::JsonBasedUdfFunctionMetadata>
getAggregateFunctionMetadata(
const std::string& functionName,
const std::vector<AggregateFunctionSignaturePtr>&
aggregateFunctionSignatures) {
std::vector<protocol::JsonBasedUdfFunctionMetadata>
jsonBasedUdfFunctionMetadataList;
jsonBasedUdfFunctionMetadataList.reserve(aggregateFunctionSignatures.size());
const protocol::FunctionKind functionKind = protocol::FunctionKind::AGGREGATE;

for (const auto& aggregateFunctionSignature : aggregateFunctionSignatures) {
protocol::JsonBasedUdfFunctionMetadata jsonBasedUdfFunctionMetadata;
jsonBasedUdfFunctionMetadata.functionKind = functionKind;
jsonBasedUdfFunctionMetadata.routineCharacteristics =
getRoutineCharacteristics(
*aggregateFunctionSignature, functionName, functionKind);
jsonBasedUdfFunctionMetadata.aggregateMetadata =
std::make_shared<protocol::AggregationFunctionMetadata>(
getAggregationFunctionMetadata(*aggregateFunctionSignature));

updateFunctionMetadata(
functionName,
*aggregateFunctionSignature,
jsonBasedUdfFunctionMetadata);
jsonBasedUdfFunctionMetadataList.emplace_back(jsonBasedUdfFunctionMetadata);
}
return jsonBasedUdfFunctionMetadataList;
}

const std::vector<protocol::JsonBasedUdfFunctionMetadata>
getScalarFunctionMetadata(
const std::string& functionName,
const std::vector<const FunctionSignature*>& functionSignatures) {
std::vector<protocol::JsonBasedUdfFunctionMetadata>
jsonBasedUdfFunctionMetadataList;
jsonBasedUdfFunctionMetadataList.reserve(functionSignatures.size());
const protocol::FunctionKind functionKind = protocol::FunctionKind::SCALAR;

for (const auto& functionSignature : functionSignatures) {
protocol::JsonBasedUdfFunctionMetadata jsonBasedUdfFunctionMetadata;
jsonBasedUdfFunctionMetadata.functionKind = functionKind;
jsonBasedUdfFunctionMetadata.routineCharacteristics =
getRoutineCharacteristics(
*functionSignature, functionName, functionKind);

updateFunctionMetadata(
functionName, *functionSignature, jsonBasedUdfFunctionMetadata);
jsonBasedUdfFunctionMetadataList.emplace_back(jsonBasedUdfFunctionMetadata);
}
return jsonBasedUdfFunctionMetadataList;
}

const std::vector<protocol::JsonBasedUdfFunctionMetadata>
getWindowFunctionMetadata(
const std::string& functionName,
const std::vector<FunctionSignaturePtr>& windowFunctionSignatures) {
std::vector<protocol::JsonBasedUdfFunctionMetadata>
jsonBasedUdfFunctionMetadataList;
jsonBasedUdfFunctionMetadataList.reserve(windowFunctionSignatures.size());
const protocol::FunctionKind functionKind = protocol::FunctionKind::WINDOW;

for (const auto& windowFunctionSignature : windowFunctionSignatures) {
protocol::JsonBasedUdfFunctionMetadata jsonBasedUdfFunctionMetadata;
jsonBasedUdfFunctionMetadata.functionKind = functionKind;
jsonBasedUdfFunctionMetadata.routineCharacteristics =
getRoutineCharacteristics(
*windowFunctionSignature, functionName, functionKind);

updateFunctionMetadata(
functionName, *windowFunctionSignature, jsonBasedUdfFunctionMetadata);
jsonBasedUdfFunctionMetadataList.emplace_back(jsonBasedUdfFunctionMetadata);
}
return jsonBasedUdfFunctionMetadataList;
}

const std::vector<protocol::JsonBasedUdfFunctionMetadata> getFunctionMetadata(
const std::string& functionName) {
if (auto aggregateFunctionSignatures =
getAggregateFunctionSignatures(functionName)) {
return getAggregateFunctionMetadata(
functionName, aggregateFunctionSignatures.value());
} else if (
auto windowFunctionSignatures =
getWindowFunctionSignatures(functionName)) {
return getWindowFunctionMetadata(
functionName, windowFunctionSignatures.value());
} else {
auto functionSignatures = getFunctionSignatures();
if (functionSignatures.find(functionName) != functionSignatures.end()) {
return getScalarFunctionMetadata(
functionName, functionSignatures.at(functionName));
}
}
VELOX_UNREACHABLE("Function kind cannot be determined");
}

} // namespace

void getJsonMetadataForFunction(
const std::string& registeredFunctionName,
nlohmann::json& jsonMetadataList) {
auto functionMetadataList = getFunctionMetadata(registeredFunctionName);
auto functionName = getFunctionName(registeredFunctionName);
for (const auto& functionMetadata : functionMetadataList) {
jsonMetadataList[functionName].emplace_back(functionMetadata);
}
}

// Builds the JSON document describing the functions returned by
// functions::getSortedAggregateNames(), getSortedScalarNames() and
// getSortedWindowNames(), processed in that order. The result maps each bare
// function name to an array of per-signature metadata.
json getJsonFunctionMetadata() {
  auto allFunctionNames = functions::getSortedAggregateNames();

  const auto scalarNames = functions::getSortedScalarNames();
  allFunctionNames.insert(
      allFunctionNames.end(), scalarNames.begin(), scalarNames.end());

  const auto windowNames = functions::getSortedWindowNames();
  allFunctionNames.insert(
      allFunctionNames.end(), windowNames.begin(), windowNames.end());

  nlohmann::json result;
  for (const auto& name : allFunctionNames) {
    getJsonMetadataForFunction(name, result);
  }
  return result;
}

} // namespace facebook::presto
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <string>

#include "velox/exec/Aggregate.h"
#include "velox/exec/AggregateFunctionRegistry.h"
#include "velox/exec/WindowFunction.h"
#include "velox/expression/SimpleFunctionRegistry.h"
#include "velox/functions/CoverageUtil.h"
#include "velox/functions/FunctionRegistry.h"

#include "presto_cpp/presto_protocol/presto_protocol.h"

namespace facebook::presto {

/// Appends the metadata of every signature of the function registered as
/// 'functionName' to 'jsonMetadataList'. Entries are added to the JSON array
/// stored under the function name stripped of any '.'-separated prefix (its
/// last component). Fails if the name is not found among the aggregate,
/// window or scalar function signatures.
void getJsonMetadataForFunction(
    const std::string& functionName,
    nlohmann::json& jsonMetadataList);

/// Returns a JSON object describing the aggregate, scalar and window
/// functions reported by the Velox function registries: each key is a
/// function name and each value is an array holding one metadata entry per
/// signature of that function.
json getJsonFunctionMetadata();

} // namespace facebook::presto
21 changes: 21 additions & 0 deletions presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,24 @@ target_link_libraries(
velox_tpch_connector
gtest
gtest_main)

# Test executable built from JsonFunctionMetadataTest.cpp.
add_executable(presto_function_metadata_test JsonFunctionMetadataTest.cpp)

# Register the executable with CTest; it runs with this source directory as
# its working directory.
add_test(
NAME presto_function_metadata_test
COMMAND presto_function_metadata_test
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})

# Link the test against gtest, the function metadata object library and the
# Presto/Velox targets listed below.
target_link_libraries(
presto_function_metadata_test
gtest
gtest_main
presto_function_metadata
presto_protocol
velox_aggregates
velox_coverage_util
velox_exec
velox_functions_prestosql
presto_type_converter
velox_window
${ANTLR4_RUNTIME})
Loading

0 comments on commit 6386575

Please sign in to comment.