From 26bec4954528445a5b041c83a520c99db361da38 Mon Sep 17 00:00:00 2001 From: zhaojunnana <44078666+zhaojunnana@users.noreply.github.com> Date: Mon, 13 Feb 2023 17:23:42 +0800 Subject: [PATCH] UDF based on existing function infra (#4804) Added UDF support using C++. --------- Co-authored-by: zhaojunnan Co-authored-by: Wey Gu Co-authored-by: Cheng Xuntao <7731943+xtcyclist@users.noreply.github.com> --- conf/nebula-graphd.conf.default | 7 + src/common/function/CMakeLists.txt | 2 + src/common/function/FunctionManager.cpp | 13 ++ src/common/function/FunctionManager.h | 2 +- src/common/function/FunctionUdfManager.cpp | 204 +++++++++++++++++++++ src/common/function/FunctionUdfManager.h | 40 ++++ src/common/function/GraphFunction.h | 38 ++++ src/graph/service/GraphFlags.h | 2 + udf/Makefile | 25 +++ udf/standard_deviation.cpp | 91 +++++++++ udf/standard_deviation.h | 60 ++++++ 11 files changed, 483 insertions(+), 1 deletion(-) create mode 100644 src/common/function/FunctionUdfManager.cpp create mode 100644 src/common/function/FunctionUdfManager.h create mode 100644 src/common/function/GraphFunction.h create mode 100644 udf/Makefile create mode 100644 udf/standard_deviation.cpp create mode 100644 udf/standard_deviation.h diff --git a/conf/nebula-graphd.conf.default b/conf/nebula-graphd.conf.default index a528daa5195..c937b665ac6 100644 --- a/conf/nebula-graphd.conf.default +++ b/conf/nebula-graphd.conf.default @@ -97,6 +97,12 @@ # if use balance data feature, only work if enable_experimental_feature is true --enable_data_balance=true +# enable udf, written in c++ only for now +--enable_udf=true + +# set the directory where the .so files of udf are stored, when enable_udf is true +--udf_path=/home/nebula/dev/nebula/udf/ + ########## session ########## # Maximum number of sessions that can be created per IP and per user --max_sessions_per_ip_per_user=300 @@ -116,3 +122,4 @@ --memory_purge_enabled=true # memory background purge interval in seconds 
--memory_purge_interval_seconds=10 + diff --git a/src/common/function/CMakeLists.txt b/src/common/function/CMakeLists.txt index 2bf6674cff9..5d9f2442dba 100644 --- a/src/common/function/CMakeLists.txt +++ b/src/common/function/CMakeLists.txt @@ -6,6 +6,8 @@ nebula_add_library( function_manager_obj OBJECT FunctionManager.cpp ../geo/GeoFunction.cpp + FunctionUdfManager.cpp + GraphFunction.h ) nebula_add_library( diff --git a/src/common/function/FunctionManager.cpp b/src/common/function/FunctionManager.cpp index 0d3bb3475ac..c39ad8887fe 100644 --- a/src/common/function/FunctionManager.cpp +++ b/src/common/function/FunctionManager.cpp @@ -10,6 +10,7 @@ #include #include +#include "FunctionUdfManager.h" #include "common/base/Base.h" #include "common/datatypes/DataSet.h" #include "common/datatypes/Edge.h" @@ -28,12 +29,18 @@ #include "common/thrift/ThriftTypes.h" #include "common/time/TimeUtils.h" #include "common/time/WallClock.h" +#include "graph/service/GraphFlags.h" + +DEFINE_bool(enable_udf, false, "enable udf"); namespace nebula { // static FunctionManager &FunctionManager::instance() { static FunctionManager instance; + if (FLAGS_enable_udf) { + static FunctionUdfManager udfManager; + } return instance; } @@ -440,6 +447,9 @@ StatusOr FunctionManager::getReturnType(const std::string &funcName } auto iter = typeSignature_.find(func); if (iter == typeSignature_.end()) { + if (FLAGS_enable_udf) { + return FunctionUdfManager::getUdfReturnType(funcName, argsType); + } return Status::Error("Function `%s' not defined", funcName.c_str()); } @@ -2930,6 +2940,9 @@ Status FunctionManager::find(const std::string &func, const size_t arity) { std::transform(func.begin(), func.end(), func.begin(), ::tolower); auto iter = functions_.find(func); if (iter == functions_.end()) { + if (FLAGS_enable_udf) { + return FunctionUdfManager::loadUdfFunction(func, arity); + } return Status::Error("Function `%s' not defined", func.c_str()); } // check arity diff --git 
a/src/common/function/FunctionManager.h b/src/common/function/FunctionManager.h
index 8b4014bf66f..0204081dd71 100644
--- a/src/common/function/FunctionManager.h
+++ b/src/common/function/FunctionManager.h
@@ -63,7 +63,6 @@ class FunctionManager final {
   static StatusOr<Value::Type> getReturnType(const std::string &funcName,
                                              const std::vector<Value::Type> &argsType);
 
- private:
   // The attributes of the function call
   struct FunctionAttributes final {
     size_t minArity_{0};
@@ -89,6 +88,7 @@ class FunctionManager final {
     }
   }
 
+ private:
   /**
    * FunctionManager functions as a singleton, since the dynamic loading is
    * process-wide.
diff --git a/src/common/function/FunctionUdfManager.cpp b/src/common/function/FunctionUdfManager.cpp
new file mode 100644
index 00000000000..e2ad870d3ee
--- /dev/null
+++ b/src/common/function/FunctionUdfManager.cpp
@@ -0,0 +1,276 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+#include "FunctionUdfManager.h"
+
+#include <dirent.h>
+#include <dlfcn.h>
+
+#include <atomic>
+#include <cerrno>
+#include <chrono>
+#include <condition_variable>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+#include "graph/service/GraphFlags.h"
+
+DEFINE_string(udf_path, "lib/udf", "path to hold the udf");
+
+namespace nebula {
+
+// Registries of every UDF discovered so far, keyed by function name.
+static std::unordered_map<std::string, Value::Type> udfFunReturnType_;
+static std::unordered_map<std::string, std::vector<std::vector<nebula::Value::Type>>>
+    udfFunInputType_;
+std::unordered_map<std::string, FunctionManager::FunctionAttributes> udfFunctions_;
+// Serializes access to the registries above: the reload thread started by the
+// constructor rewrites them while getUdfReturnType()/loadUdfFunction() read them.
+static std::mutex udfRegistryMutex_;
+
+std::atomic<bool> expired_{};
+std::atomic<bool> try_to_expire_{};
+std::mutex mutex_;
+std::condition_variable expired_cond_;
+
+// Returns the process-wide manager; constructing it scans FLAGS_udf_path.
+FunctionUdfManager &FunctionUdfManager::instance() {
+  static FunctionUdfManager instance;
+  return instance;
+}
+
+// Lists the entries of directory path whose names end with ftype (e.g. ".so").
+// Returns an empty list, after logging the reason, when path cannot be opened.
+std::vector<std::string> getFilesList(const std::string &path, const char *ftype) {
+  std::vector<std::string> filenames;
+  DIR *pDir = opendir(path.c_str());
+  if (pDir == nullptr) {
+    LOG(ERROR) << "UDF Folder doesn't Exist! " << path << ": " << strerror(errno);
+    return filenames;
+  }
+  const size_t suffixLen = strlen(ftype);
+  struct dirent *ptr = nullptr;
+  while ((ptr = readdir(pDir)) != nullptr) {
+    const size_t nameLen = strlen(ptr->d_name);
+    // Names shorter than the suffix cannot match; comparing them anyway would
+    // read before the start of d_name.
+    if (nameLen < suffixLen) {
+      continue;
+    }
+    if (strcmp(ptr->d_name, ".") != 0 && strcmp(ptr->d_name, "..") != 0 &&
+        strcmp(ptr->d_name + nameLen - suffixLen, ftype) == 0) {
+      filenames.emplace_back(ptr->d_name);
+      LOG(INFO) << "Load UDF SO Name: " << ptr->d_name;
+    }
+  }
+  closedir(pDir);
+  return filenames;
+}
+
+// Looks up the "create" factory exported by an opened UDF library.
+// Returns nullptr (after logging) when the symbol cannot be resolved.
+FunctionUdfManager::create_f *FunctionUdfManager::getGraphFunctionClass(void *func_handle) {
+  dlerror();  // clear any stale error so the check below reflects this dlsym only
+  auto *create_func = reinterpret_cast<create_f *>(dlsym(func_handle, "create"));
+  const char *dlsym_error = dlerror();
+  if (dlsym_error) {
+    LOG(ERROR) << "Cannot load symbol create: " << dlsym_error;
+    return nullptr;
+  }
+  return create_func;
+}
+
+// Looks up the "destroy" deleter exported by an opened UDF library.
+// Returns nullptr (after logging) when the symbol cannot be resolved.
+FunctionUdfManager::destroy_f *FunctionUdfManager::deleteGraphFunctionClass(void *func_handle) {
+  dlerror();  // clear any stale error so the check below reflects this dlsym only
+  auto *destroy_func = reinterpret_cast<destroy_f *>(dlsym(func_handle, "destroy"));
+  const char *dlsym_error = dlerror();
+  if (dlsym_error) {
+    LOG(ERROR) << "Cannot load symbol destroy: " << dlsym_error;
+    return nullptr;
+  }
+  return destroy_func;
+}
+
+// Loads the UDF libraries once, then starts a detached thread that rescans
+// FLAGS_udf_path every 300 seconds.
+FunctionUdfManager::FunctionUdfManager() {
+  initAndLoadSoFunction();
+  expired_ = true;
+  try_to_expire_ = false;
+
+  // NOTE(review): nothing in this file ever sets try_to_expire_, and the thread
+  // keeps using `this` after detach(); confirm the manager outlives every use
+  // before relying on shutdown ordering.
+  std::thread([this]() {
+    while (!try_to_expire_) {
+      std::this_thread::sleep_for(std::chrono::seconds(300));
+      try {
+        initAndLoadSoFunction();
+      } catch (...) {
+        // An exception escaping a thread entry point would terminate the process.
+        LOG(ERROR) << "Reload UDF so library Error";
+      }
+    }
+    {
+      std::lock_guard<std::mutex> locker(mutex_);
+      expired_ = true;
+      expired_cond_.notify_one();
+    }
+  }).detach();
+}
+
+// Scans FLAGS_udf_path for *.so files and (re)registers the UDF each exports.
+// A library that fails to load or lacks create/destroy is logged and skipped;
+// the remaining libraries are still processed.
+void FunctionUdfManager::initAndLoadSoFunction() {
+  auto udfPath = FLAGS_udf_path;
+  LOG(INFO) << "Load UDF so library: " << udfPath;
+  // The default flag value ("lib/udf") has no trailing slash, so add one before
+  // the file name is appended.
+  if (!udfPath.empty() && udfPath.back() != '/') {
+    udfPath.push_back('/');
+  }
+  std::vector<std::string> files = getFilesList(udfPath, ".so");
+
+  for (auto &file : files) {
+    const std::string so_path_string = udfPath + file;
+    const char *soPath = so_path_string.c_str();
+    try {
+      void *raw_handle = dlopen(soPath, RTLD_LAZY);
+      if (!raw_handle) {
+        LOG(ERROR) << "Cannot load udf library: " << dlerror();
+        continue;
+      }
+      // Closes the library on every exit path, after gf below is destroyed.
+      std::unique_ptr<void, int (*)(void *)> func_handle(raw_handle, &dlclose);
+
+      create_f *create_func = getGraphFunctionClass(func_handle.get());
+      destroy_f *destroy_func = deleteGraphFunctionClass(func_handle.get());
+      if (create_func == nullptr || destroy_func == nullptr) {
+        LOG(ERROR) << "GraphFunction Create Or Destroy Error: " << soPath;
+        continue;
+      }
+
+      std::unique_ptr<GraphFunction, destroy_f *> gf(create_func(), destroy_func);
+      char *funName = gf ? gf->name() : nullptr;
+      if (funName == nullptr) {
+        LOG(ERROR) << "GraphFunction Create returned no usable function: " << soPath;
+        continue;
+      }
+      {
+        // Assign rather than emplace so a rebuilt library refreshes its signature.
+        std::lock_guard<std::mutex> guard(udfRegistryMutex_);
+        udfFunInputType_[funName] = gf->inputType();
+        udfFunReturnType_[funName] = gf->returnType();
+      }
+      addSoUdfFunction(funName, soPath, gf->minArity(), gf->maxArity(), gf->isPure());
+    } catch (...) {
+      LOG(ERROR) << "load So library Error: " << soPath;
+    }
+  }
+}
+
+// Returns the declared return type of UDF func when argsType equals one of its
+// declared signatures, or when a signature starts with NULLVALUE/__EMPTY__.
+StatusOr<Value::Type> FunctionUdfManager::getUdfReturnType(
+    std::string func, const std::vector<Value::Type> &argsType) {
+  std::lock_guard<std::mutex> guard(udfRegistryMutex_);
+  auto returnIter = udfFunReturnType_.find(func);
+  if (returnIter == udfFunReturnType_.end()) {
+    return Status::Error("Function `%s' not defined", func.c_str());
+  }
+  auto inputIter = udfFunInputType_.find(func);
+  if (inputIter != udfFunInputType_.end()) {
+    for (const auto &args : inputIter->second) {
+      // args may be empty for a zero-argument signature; never index it blindly.
+      const bool acceptsAny = !args.empty() && (args[0] == Value::Type::NULLVALUE ||
+                                                args[0] == Value::Type::__EMPTY__);
+      if (argsType == args || acceptsAny) {
+        return returnIter->second;
+      }
+    }
+  }
+  return Status::Error("Parameter's type error");
+}
+
+// Returns the attributes registered for UDF func, or an error when it is unknown
+// or arity falls outside its [minArity_, maxArity_] range.
+StatusOr<const FunctionManager::FunctionAttributes> nebula::FunctionUdfManager::loadUdfFunction(
+    std::string func, size_t arity) {
+  std::lock_guard<std::mutex> guard(udfRegistryMutex_);
+  auto iter = udfFunctions_.find(func);
+  if (iter == udfFunctions_.end()) {
+    return Status::Error("Function `%s' not defined", func.c_str());
+  }
+  auto minArity = iter->second.minArity_;
+  auto maxArity = iter->second.maxArity_;
+  if (arity < minArity || arity > maxArity) {
+    if (minArity == maxArity) {
+      return Status::Error(
+          "Arity not match for function `%s': "
+          "provided %lu but %lu expected.",
+          func.c_str(),
+          arity,
+          minArity);
+    } else {
+      return Status::Error(
+          "Arity not match for function `%s': "
+          "provided %lu but %lu-%lu expected.",
+          func.c_str(),
+          arity,
+          minArity,
+          maxArity);
+    }
+  }
+  return iter->second;
+}
+
+// Registers (or refreshes) the attributes of UDF funName. The stored body
+// re-opens soPath on every invocation, runs GraphFunction::body, and returns
+// Value::kNullBadData when the library cannot be loaded or the call throws.
+void FunctionUdfManager::addSoUdfFunction(
+    char *funName, const char *soPath, size_t minArity, size_t maxArity, bool isPure) {
+  std::string path = soPath;
+  std::lock_guard<std::mutex> guard(udfRegistryMutex_);
+  auto &attr = udfFunctions_[funName];
+  attr.minArity_ = minArity;
+  attr.maxArity_ = maxArity;
+  attr.isAlwaysPure_ = isPure;
+  attr.body_ = [path](const auto &args) -> Value {
+    try {
+      void *raw_handle = dlopen(path.c_str(), RTLD_LAZY);
+      if (!raw_handle) {
+        LOG(ERROR) << "Cannot load udf library: " << dlerror();
+        return Value::kNullBadData;
+      }
+      // Declared before gf so the library stays open until gf is destroyed.
+      std::unique_ptr<void, int (*)(void *)> func_handle(raw_handle, &dlclose);
+
+      create_f *create_func = getGraphFunctionClass(func_handle.get());
+      destroy_f *destroy_func = deleteGraphFunctionClass(func_handle.get());
+      if (create_func == nullptr || destroy_func == nullptr) {
+        return Value::kNullBadData;
+      }
+
+      std::unique_ptr<GraphFunction, destroy_f *> gf(create_func(), destroy_func);
+      if (!gf) {
+        return Value::kNullBadData;
+      }
+      return gf->body(args);
+    } catch (...) {
+      return Value::kNullBadData;
+    }
+  };
+}
+
+}  // namespace nebula
diff --git a/src/common/function/FunctionUdfManager.h b/src/common/function/FunctionUdfManager.h
new file mode 100644
index 00000000000..2cd1c8e6f2f
--- /dev/null
+++ b/src/common/function/FunctionUdfManager.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+#ifndef COMMON_FUNCTION_FUNCTIONUDFMANAGER_H_
+#define COMMON_FUNCTION_FUNCTIONUDFMANAGER_H_
+
+#include "FunctionManager.h"
+#include "GraphFunction.h"
+
+namespace nebula {
+
+// Loads UDF shared libraries from FLAGS_udf_path and serves them to FunctionManager.
+class FunctionUdfManager {
+ public:
+  typedef GraphFunction *(create_f)();
+  typedef void(destroy_f)(GraphFunction *);
+
+  static StatusOr<Value::Type> getUdfReturnType(std::string functionName,
+                                                const std::vector<Value::Type> &argsType);
+
+  static StatusOr<const FunctionManager::FunctionAttributes> loadUdfFunction(
+      std::string functionName, size_t arity);
+
+  static FunctionUdfManager &instance();
+
+  FunctionUdfManager();
+
+ private:
+  static create_f *getGraphFunctionClass(void *func_handle);
+  static destroy_f *deleteGraphFunctionClass(void *func_handle);
+
+  void addSoUdfFunction(
+      char *funName, const char *soPath, size_t minArity, size_t maxArity, bool isPure);
+  void initAndLoadSoFunction();
+};
+
+}  // namespace nebula
+#endif  // COMMON_FUNCTION_FUNCTIONUDFMANAGER_H_
diff --git a/src/common/function/GraphFunction.h b/src/common/function/GraphFunction.h
new file mode 100644
index 00000000000..4848b160846
--- /dev/null
+++ b/src/common/function/GraphFunction.h
@@ -0,0 +1,50 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+#ifndef COMMON_FUNCTION_GRAPHFUNCTION_H
+#define COMMON_FUNCTION_GRAPHFUNCTION_H
+
+#include <cstddef>
+#include <functional>
+#include <vector>
+
+#include "common/datatypes/Value.h"
+
+class GraphFunction;
+
+// Entry points every UDF shared library must export with C linkage; they are
+// looked up by name ("create"/"destroy") through dlsym.
+extern "C" GraphFunction *create();
+extern "C" void destroy(GraphFunction *function);
+
+// Interface a UDF implements; see udf/standard_deviation.{h,cpp} for an example.
+class GraphFunction {
+ public:
+  virtual ~GraphFunction() = default;
+
+  // Name under which the function is registered.
+  virtual char *name() = 0;
+
+  // Accepted argument-type signatures, one inner vector per signature.
+  virtual std::vector<std::vector<nebula::Value::Type>> inputType() = 0;
+
+  // Type reported for the function's result.
+  virtual nebula::Value::Type returnType() = 0;
+
+  // Smallest number of arguments accepted.
+  virtual size_t minArity() = 0;
+
+  // Largest number of arguments accepted.
+  virtual size_t maxArity() = 0;
+
+  // Stored as FunctionAttributes::isAlwaysPure_ when the function is registered.
+  virtual bool isPure() = 0;
+
+  // Evaluates the function on args and returns the result.
+  virtual nebula::Value body(
+      const std::vector<std::reference_wrapper<const nebula::Value>> &args) = 0;
+};
+
+#endif  // COMMON_FUNCTION_GRAPHFUNCTION_H
diff --git a/src/graph/service/GraphFlags.h b/src/graph/service/GraphFlags.h
index 3ac25c64f11..2aa68c0e0d5 100644
--- a/src/graph/service/GraphFlags.h
+++ b/src/graph/service/GraphFlags.h
@@ -22,6 +22,8 @@ DECLARE_int32(listen_backlog);
 DECLARE_string(listen_netdev);
 DECLARE_string(local_ip);
 DECLARE_string(pid_file);
+DECLARE_bool(enable_udf);
+DECLARE_string(udf_path);
 DECLARE_bool(local_config);
 DECLARE_bool(accept_partial_success);
 DECLARE_bool(disable_octal_escape_char);
diff --git a/udf/Makefile b/udf/Makefile
new file mode 100644
index 00000000000..28e4babda87
--- /dev/null
+++ b/udf/Makefile
@@ -0,0 +1,25 @@
+# Copyright (c) 2020 vesoft inc. All rights reserved.
+#
+# This source code is licensed under Apache 2.0 License.
+#
+
+ifneq ($(wildcard ../build/third-party/install),)
+  3PP_PATH := ../build/third-party/install
+else ifneq ($(wildcard /opt/vesoft/third-party/3.3),)
+  3PP_PATH := /opt/vesoft/third-party/3.3
+else ifneq ($(wildcard /opt/vesoft/third-party/3.0),)
+  3PP_PATH := /opt/vesoft/third-party/3.0
+else
+  $(error "Cannot find the third-party installation directory")
+endif
+
+CXX := g++
+CXX_FLAGS := -c -I ../src/ -I $(3PP_PATH)/include/ -fPIC
+
+.PHONY: all clean
+all: standard_deviation.cpp
+	$(CXX) $(CXX_FLAGS) standard_deviation.cpp -o standard_deviation.o
+	$(CXX) -shared -o standard_deviation.so standard_deviation.o
+
+clean:
+	rm -f ./*.o ./*.so
diff --git a/udf/standard_deviation.cpp b/udf/standard_deviation.cpp
new file mode 100644
index 00000000000..7be5e157249
--- /dev/null
+++ b/udf/standard_deviation.cpp
@@ -0,0 +1,115 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+#include "standard_deviation.h"
+
+#include <cmath>
+#include <cstddef>
+#include <vector>
+
+#include "../src/common/datatypes/List.h"
+
+// Factory resolved by name ("create") when the library is loaded with dlopen.
+extern "C" GraphFunction *create() {
+  return new standard_deviation;
+}
+// Counterpart of create(), resolved by name ("destroy").
+extern "C" void destroy(GraphFunction *function) {
+  delete function;
+}
+
+char *standard_deviation::name() {
+  // The interface returns a mutable char *, so hand out writable static storage
+  // instead of casting away the constness of a string literal.
+  static char kName[] = "standard_deviation";
+  return kName;
+}
+
+std::vector<std::vector<nebula::Value::Type>> standard_deviation::inputType() {
+  std::vector<nebula::Value::Type> vtp = {nebula::Value::Type::LIST};
+  std::vector<std::vector<nebula::Value::Type>> vvtp = {vtp};
+  return vvtp;
+}
+
+nebula::Value::Type standard_deviation::returnType() {
+  return nebula::Value::Type::FLOAT;
+}
+
+size_t standard_deviation::minArity() {
+  return 1;
+}
+
+size_t standard_deviation::maxArity() {
+  return 1;
+}
+
+bool standard_deviation::isPure() {
+  return true;
+}
+
+// Population standard deviation (divides by N) of numbers.
+// Returns 0.0 for an empty vector instead of evaluating 0.0 / 0.
+double caculate_standard_deviation(const std::vector<double> &numbers) {
+  if (numbers.empty()) {
+    return 0.0;
+  }
+  const double count = static_cast<double>(numbers.size());
+
+  double sum = 0;
+  for (double number : numbers) {
+    sum += number;
+  }
+  const double average = sum / count;
+
+  double variance = 0;
+  for (double number : numbers) {
+    const double difference = number - average;
+    variance += difference * difference;
+  }
+  variance /= count;
+
+  return std::sqrt(variance);
+}
+
+// Returns the standard deviation of a LIST of INT/FLOAT values as a FLOAT.
+// Yields null for a null argument, an empty list, a list holding any other
+// element type, or any non-list argument.
+nebula::Value standard_deviation::body(
+    const std::vector<std::reference_wrapper<const nebula::Value>> &args) {
+  if (args.empty()) {
+    return nebula::Value::kNullValue;
+  }
+  const nebula::Value &arg = args[0].get();
+  switch (arg.type()) {
+    case nebula::Value::Type::NULLVALUE: {
+      return nebula::Value::kNullValue;
+    }
+    case nebula::Value::Type::LIST: {
+      // Bind by reference: copying the whole list per call is unnecessary.
+      const auto &list = arg.getList();
+      const size_t size = list.size();
+      if (size == 0) {
+        return nebula::Value::kNullValue;
+      }
+
+      std::vector<double> numbers;
+      numbers.reserve(size);
+      for (size_t i = 0; i < size; i++) {
+        const auto &value = list[i];
+        if (value.isInt()) {
+          numbers.push_back(static_cast<double>(value.getInt()));
+        } else if (value.isFloat()) {
+          numbers.push_back(value.getFloat());
+        } else {
+          return nebula::Value::kNullValue;
+        }
+      }
+      return nebula::Value(caculate_standard_deviation(numbers));
+    }
+    default: {
+      return nebula::Value::kNullValue;
+    }
+  }
+}
diff --git a/udf/standard_deviation.h b/udf/standard_deviation.h
new file mode 100644
index 00000000000..769e66423c4
--- /dev/null
+++ b/udf/standard_deviation.h
@@ -0,0 +1,60 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+#ifndef UDF_PROJECT_STANDARD_DEVIATION_H
+#define UDF_PROJECT_STANDARD_DEVIATION_H
+
+#include "../src/common/function/GraphFunction.h"
+
+// Example of a UDF function that calculates the standard deviation of a set of numbers.
+// > YIELD standard_deviation([1,2,3])
+// +-----------------------------+
+// | standard_deviation([1,2,3]) |
+// +-----------------------------+
+// | 0.816496580927726           |
+// +-----------------------------+
+
+// > YIELD standard_deviation([1,1,1])
+// +-----------------------------+
+// | standard_deviation([1,1,1]) |
+// +-----------------------------+
+// | 0.0                         |
+// +-----------------------------+
+
+// > GO 1 TO 2 STEPS FROM "player100" OVER follow
+//      YIELD properties(edge).degree AS d | YIELD collect($-.d)
+// +--------------------------+
+// | collect($-.d)            |
+// +--------------------------+
+// | [95, 95, 95, 90, 95, 90] |
+// +--------------------------+
+
+// > GO 1 TO 2 STEPS FROM "player100" OVER follow
+//      YIELD properties(edge).degree AS d | YIELD collect($-.d) AS d |
+//      YIELD standard_deviation($-.d)
+// +--------------------------+
+// | standard_deviation($-.d) |
+// +--------------------------+
+// | 2.357022603955158        |
+// +--------------------------+
+
+// Population standard deviation of the INT/FLOAT elements of one LIST argument.
+class standard_deviation : public GraphFunction {
+ public:
+  char *name() override;
+
+  std::vector<std::vector<nebula::Value::Type>> inputType() override;
+
+  nebula::Value::Type returnType() override;
+
+  size_t minArity() override;
+  size_t maxArity() override;
+
+  bool isPure() override;
+
+  nebula::Value body(const std::vector<std::reference_wrapper<const nebula::Value>> &args) override;
+};
+
+#endif  // UDF_PROJECT_STANDARD_DEVIATION_H