From e307fd939b5dc7346964fd10358f8b2b29417750 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 13:20:08 +0800 Subject: [PATCH 1/9] build(deps-dev): bump urllib3 from 1.26.18 to 1.26.19 in /docs (#3948) Bumps [urllib3](https://github.com/urllib3/urllib3) from 1.26.18 to 1.26.19. - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/1.26.19/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/1.26.18...1.26.19) --- updated-dependencies: - dependency-name: urllib3 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/poetry.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/poetry.lock b/docs/poetry.lock index ca6c9ccb8c6..39577275304 100644 --- a/docs/poetry.lock +++ b/docs/poetry.lock @@ -670,13 +670,13 @@ test = ["coverage", "pytest", "pytest-cov"] [[package]] name = "urllib3" -version = "1.26.18" +version = "1.26.19" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ - {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, - {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, + {file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"}, + {file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"}, ] [package.extras] From 818d292cc5190c65ef3ead678f6742e828d0f163 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 26 Jun 2024 12:34:11 +0800 Subject: [PATCH 2/9] feat(udf): isin (#3939) --- cases/query/udf_query.yaml | 19 ++++++++++ hybridse/src/udf/default_defs/array_def.cc | 41 ++++++++++++++++++++-- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/cases/query/udf_query.yaml b/cases/query/udf_query.yaml index fefe1380dbb..ee9cad2d667 100644 --- a/cases/query/udf_query.yaml +++ b/cases/query/udf_query.yaml @@ -536,6 +536,25 @@ cases: data: | true, true, false, false, true, false, true, false, true, false, true + - id: isin + mode: request-unsupport + inputs: + - name: t1 + columns: ["col1:int32", "std_ts:timestamp", "col2:string"] + indexs: ["index1:col1:std_ts"] + rows: + - [1, 1590115420001, "ABCabcabc"] + sql: | + select + isin(2, [2,2]) as c0, + isin(cast(3 as int64), ARRAY[NULL, 1, 2]) as c1 + expect: + columns: + - c0 bool + - c1 bool + data: | + true, false + - id: array_split mode: request-unsupport inputs: diff --git a/hybridse/src/udf/default_defs/array_def.cc b/hybridse/src/udf/default_defs/array_def.cc index a1f35f38e35..b5c36bc3d7e 100644 --- a/hybridse/src/udf/default_defs/array_def.cc +++ b/hybridse/src/udf/default_defs/array_def.cc @@ -37,8 +37,30 @@ struct ArrayContains { // - bool/intxx/float/double -> bool/intxx/float/double // - Timestamp/Date/StringRef -> Timestamp*/Date*/StringRef* bool operator()(ArrayRef* arr, ParamType v, bool is_null) { - // NOTE: array_contains([null], null) returns null - // this might not expected + for (uint64_t i = 0; i < arr->size; ++i) { + if constexpr (std::is_pointer_v) { + // null or same value returns true + if ((is_null && arr->nullables[i]) || (!arr->nullables[i] && *arr->raw[i] == *v)) { + return true; + } + } else { + if ((is_null && arr->nullables[i]) || (!arr->nullables[i] && arr->raw[i] == v)) { + return true; + } + } + } + return false; + } +}; + +template +struct IsIn { + // udf registry types + using Args = std::tuple, ArrayRef>; + + using ParamType = typename DataTypeTrait::CCallArgType; + + bool operator()(ParamType v, bool is_null, ArrayRef* arr) { for (uint64_t i = 0; i < arr->size; ++i) { if constexpr (std::is_pointer_v) { // null or same value returns true @@ -98,6 +120,21 @@ void DefaultUdfLibrary::InitArrayUdfs() { @since 0.7.0 )"); + RegisterExternalTemplate("isin") + .args_in() + .doc(R"( + @brief isin(value, array) - Returns true if the array contains the value. + + Example: + + @code{.sql} + select isin(2, [2,2]) as c0; + -- output true + @endcode + + @since 0.9.1 + )"); + RegisterExternal("split_array") .returns>() .return_by_arg(true) From 6b06e38a7c8e4f794503afb8b8ccfab8a1b96297 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 26 Jun 2024 12:52:41 +0800 Subject: [PATCH 3/9] feat(#3916): support @@execute_mode = 'request' (#3924) --- hybridse/include/vm/engine.h | 4 ++ hybridse/src/vm/engine.cc | 45 +++++++++++++++++------ hybridse/src/vm/engine_context.cc | 3 +- src/client/tablet_client.cc | 3 +- src/client/tablet_client.h | 2 +- src/sdk/internal/system_variable.cc | 57 +++++++++++++++++++++++++++++ src/sdk/internal/system_variable.h | 40 ++++++++++++++++++++ src/sdk/sql_cluster_router.cc | 37 ++++++++++++++----- src/sdk/sql_cluster_router.h | 5 ++- src/sdk/sql_router.h | 2 +- src/sdk/sql_router_sdk.i | 1 + src/tablet/tablet_impl.cc | 8 +++- 12 files changed, 178 insertions(+), 29 deletions(-) create mode 100644 src/sdk/internal/system_variable.cc create mode 100644 src/sdk/internal/system_variable.h diff --git a/hybridse/include/vm/engine.h b/hybridse/include/vm/engine.h index 09586a8b03d..7e183d43c33 100644 --- a/hybridse/include/vm/engine.h +++ b/hybridse/include/vm/engine.h @@ -429,6 +429,10 @@ class Engine { /// request row info exists in 'values' option, as a format of: /// 1. [(col1_expr, col2_expr, ... ), (...), ...] /// 2. (col1_expr, col2_expr, ... ) + // + // This function only check on request/batchrequest mode, for batch mode it does nothing. + // As for old-fashioned usage, request row does not need to appear in SQL, so it won't report + // error even request rows is empty, instead checks should performed at the very beginning of Compute. static absl::Status ExtractRequestRowsInSQL(SqlContext* ctx); std::shared_ptr GetCacheLocked(const std::string& db, diff --git a/hybridse/src/vm/engine.cc b/hybridse/src/vm/engine.cc index 5c179fb79b6..a2aaa30f305 100644 --- a/hybridse/src/vm/engine.cc +++ b/hybridse/src/vm/engine.cc @@ -462,17 +462,27 @@ int32_t RequestRunSession::Run(const uint32_t task_id, const Row& in_row, Row* o if (!sql_request_rows.empty()) { row = sql_request_rows.at(0); } - auto task = std::dynamic_pointer_cast(compile_info_) - ->get_sql_context() - .cluster_job->GetTask(task_id) - .GetRoot(); + + auto info = std::dynamic_pointer_cast(compile_info_); + auto main_task_id = info->GetClusterJob()->main_task_id(); + if (task_id == main_task_id && !info->GetRequestSchema().empty() && row.empty()) { + // a non-empty request row required but it not. + // checks only happen for a top level query, not subquery, + // since query internally may construct a empty row as input row, + // not meaning row with no columns, but row with all column values NULL + LOG(WARNING) << "request SQL requires a non-empty request row, but empty row received"; + // TODO(someone): use status + return common::StatusCode::kRunSessionError; + } + + auto task = info->get_sql_context().cluster_job->GetTask(task_id).GetRoot(); + if (nullptr == task) { LOG(WARNING) << "fail to run request plan: taskid" << task_id << " not exist!"; return -2; } DLOG(INFO) << "Request Row Run with task_id " << task_id; - RunnerContext ctx(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job, row, - sp_name_, is_debug_); + RunnerContext ctx(info->get_sql_context().cluster_job, row, sp_name_, is_debug_); auto output = task->RunWithCache(ctx); if (!output) { LOG(WARNING) << "Run request plan output is null"; @@ -491,13 +501,24 @@ int32_t BatchRequestRunSession::Run(const std::vector& request_batch, std:: } int32_t BatchRequestRunSession::Run(const uint32_t id, const std::vector& request_batch, std::vector& output) { - std::vector<::hybridse::codec::Row>& sql_request_rows = - std::dynamic_pointer_cast(GetCompileInfo())->get_sql_context().request_rows; + auto info = std::dynamic_pointer_cast(GetCompileInfo()); + std::vector<::hybridse::codec::Row>& sql_request_rows = info->get_sql_context().request_rows; + + std::vector<::hybridse::codec::Row> rows = sql_request_rows; + if (rows.empty()) { + rows = request_batch; + } + + auto main_task_id = info->GetClusterJob()->main_task_id(); + if (id != main_task_id && !info->GetRequestSchema().empty() && rows.empty()) { + // a non-empty request row list required but it not + LOG(WARNING) << "batchrequest SQL requires a non-empty request row list, but empty row list received"; + // TODO(someone): use status + return common::StatusCode::kRunSessionError; + } - RunnerContext ctx(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job, - sql_request_rows.empty() ? request_batch : sql_request_rows, sp_name_, is_debug_); - auto task = - std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job->GetTask(id).GetRoot(); + RunnerContext ctx(info->get_sql_context().cluster_job, rows, sp_name_, is_debug_); + auto task = info->get_sql_context().cluster_job->GetTask(id).GetRoot(); if (nullptr == task) { LOG(WARNING) << "Fail to run request plan: taskid" << id << " not exist!"; return -2; diff --git a/hybridse/src/vm/engine_context.cc b/hybridse/src/vm/engine_context.cc index 570726aa0eb..47ff39437d6 100644 --- a/hybridse/src/vm/engine_context.cc +++ b/hybridse/src/vm/engine_context.cc @@ -17,6 +17,7 @@ #include "vm/engine_context.h" #include "absl/container/flat_hash_map.h" +#include "absl/strings/ascii.h" namespace hybridse { namespace vm { @@ -61,7 +62,7 @@ std::string EngineModeName(EngineMode mode) { absl::StatusOr UnparseEngineMode(absl::string_view str) { auto& m = getModeMap(); - auto it = m.find(str); + auto it = m.find(absl::AsciiStrToLower(str)); if (it != m.end()) { return it->second; } diff --git a/src/client/tablet_client.cc b/src/client/tablet_client.cc index a1dd925fcde..cbfb794817f 100644 --- a/src/client/tablet_client.cc +++ b/src/client/tablet_client.cc @@ -73,6 +73,7 @@ bool TabletClient::Query(const std::string& db, const std::string& sql, const st } bool TabletClient::Query(const std::string& db, const std::string& sql, + hybridse::vm::EngineMode default_mode, const std::vector& parameter_types, const std::string& parameter_row, brpc::Controller* cntl, ::openmldb::api::QueryResponse* response, const bool is_debug) { @@ -80,7 +81,7 @@ bool TabletClient::Query(const std::string& db, const std::string& sql, ::openmldb::api::QueryRequest request; request.set_sql(sql); request.set_db(db); - request.set_is_batch(true); + request.set_is_batch(default_mode == hybridse::vm::kBatchMode); request.set_is_debug(is_debug); request.set_parameter_row_size(parameter_row.size()); request.set_parameter_row_slices(1); diff --git a/src/client/tablet_client.h b/src/client/tablet_client.h index 177124208fc..33188adadcc 100644 --- a/src/client/tablet_client.h +++ b/src/client/tablet_client.h @@ -64,7 +64,7 @@ class TabletClient : public Client { const openmldb::common::VersionPair& pair, std::string& msg); // NOLINT - bool Query(const std::string& db, const std::string& sql, + bool Query(const std::string& db, const std::string& sql, hybridse::vm::EngineMode default_mode, const std::vector& parameter_types, const std::string& parameter_row, brpc::Controller* cntl, ::openmldb::api::QueryResponse* response, const bool is_debug = false); diff --git a/src/sdk/internal/system_variable.cc b/src/sdk/internal/system_variable.cc new file mode 100644 index 00000000000..dc818dfb412 --- /dev/null +++ b/src/sdk/internal/system_variable.cc @@ -0,0 +1,57 @@ +/** + * Copyright (c) 2024 OpenMLDB Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sdk/internal/system_variable.h" + +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" + +namespace openmldb { +namespace sdk { +namespace internal { + +static SystemVariables CreateSystemVariablePresets() { + SystemVariables map = { + {"execute_mode", {"online", "request", "offline"}}, + // TODO(someone): add all + }; + return map; +} + +const SystemVariables& GetSystemVariablePresets() { + static const SystemVariables& map = *new auto(CreateSystemVariablePresets()); + return map; +} +absl::Status CheckSystemVariableSet(absl::string_view key, absl::string_view val) { + auto& presets = GetSystemVariablePresets(); + auto it = presets.find(absl::AsciiStrToLower(key)); + if (it == presets.end()) { + return absl::InvalidArgumentError(absl::Substitute("key '$0' not found as a system variable", key)); + } + + if (it->second.find(absl::AsciiStrToLower(val)) == it->second.end()) { + return absl::InvalidArgumentError( + absl::Substitute("invalid value for system variable '$0', expect one of [$1], but got $2", key, + absl::StrJoin(it->second, ", "), val)); + } + + return absl::OkStatus(); +} +} // namespace internal +} // namespace sdk +} // namespace openmldb diff --git a/src/sdk/internal/system_variable.h b/src/sdk/internal/system_variable.h new file mode 100644 index 00000000000..75edd014ab7 --- /dev/null +++ b/src/sdk/internal/system_variable.h @@ -0,0 +1,40 @@ +/** + * Copyright (c) 2024 OpenMLDB Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_SDK_INTERNAL_SYSTEM_VARIABLE_H_ +#define SRC_SDK_INTERNAL_SYSTEM_VARIABLE_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" + +namespace openmldb { +namespace sdk { +namespace internal { + +using SystemVariables = absl::flat_hash_map>; + +const SystemVariables& GetSystemVariablePresets(); + +// check if the stmt 'set {key} = {val}' has a valid semantic +// key and value for system variable is case insensetive +absl::Status CheckSystemVariableSet(absl::string_view key, absl::string_view val); + +} // namespace internal +} // namespace sdk +} // namespace openmldb + +#endif // SRC_SDK_INTERNAL_SYSTEM_VARIABLE_H_ diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index 068538ba5e0..dbdd7dede9d 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -56,6 +56,7 @@ #include "sdk/base.h" #include "sdk/base_impl.h" #include "sdk/batch_request_result_set_sql.h" +#include "sdk/internal/system_variable.h" #include "sdk/job_table_helper.h" #include "sdk/node_adapter.h" #include "sdk/query_future_impl.h" @@ -1215,8 +1216,8 @@ std::shared_ptr<::hybridse::sdk::ResultSet> SQLClusterRouter::ExecuteSQLParamete cntl->set_timeout_ms(options_->request_timeout); DLOG(INFO) << "send query to tablet " << client->GetEndpoint(); auto response = std::make_shared<::openmldb::api::QueryResponse>(); - if (!client->Query(db, sql, parameter_types, parameter ? parameter->GetRow() : "", cntl.get(), response.get(), - options_->enable_debug)) { + if (!client->Query(db, sql, GetDefaultEngineMode(), parameter_types, parameter ? parameter->GetRow() : "", + cntl.get(), response.get(), options_->enable_debug)) { // rpc error is in cntl or response RPC_STATUS_AND_WARN(status, cntl, response, "Query rpc failed"); return {}; @@ -1995,7 +1996,7 @@ std::shared_ptr SQLClusterRouter::HandleSQLCmd(const h case hybridse::node::kCmdShowGlobalVariables: { std::string db = openmldb::nameserver::INFORMATION_SCHEMA_DB; std::string table = openmldb::nameserver::GLOBAL_VARIABLES; - std::string sql = "select * from " + table; + std::string sql = "select * from " + table + " CONFIG (execute_mode = 'online')"; ::hybridse::sdk::Status status; auto rs = ExecuteSQLParameterized(db, sql, std::shared_ptr(), &status); if (status.code != 0) { @@ -2884,9 +2885,7 @@ std::shared_ptr SQLClusterRouter::ExecuteSQL( } case hybridse::node::kPlanTypeFuncDef: case hybridse::node::kPlanTypeQuery: { - ::hybridse::vm::EngineMode default_mode = (!cluster_sdk_->IsClusterMode() || is_online_mode) - ? ::hybridse::vm::EngineMode::kBatchMode - : ::hybridse::vm::EngineMode::kOffline; + ::hybridse::vm::EngineMode default_mode = GetDefaultEngineMode(); // execute_mode in query config clause takes precedence auto mode = ::hybridse::vm::Engine::TryDetermineEngineMode(sql, default_mode); if (mode != ::hybridse::vm::EngineMode::kOffline) { @@ -3191,10 +3190,27 @@ std::shared_ptr SQLClusterRouter::ExecuteOfflineQuery( } } -bool SQLClusterRouter::IsOnlineMode() { +::hybridse::vm::EngineMode SQLClusterRouter::GetDefaultEngineMode() const { std::lock_guard<::openmldb::base::SpinMutex> lock(mu_); auto it = session_variables_.find("execute_mode"); - if (it != session_variables_.end() && it->second == "online") { + if (it != session_variables_.end()) { + // 1. infer from system variable + auto m = hybridse::vm::UnparseEngineMode(it->second).value_or(hybridse::vm::EngineMode::kBatchMode); + + // 2. standalone mode do not have offline + if (!cluster_sdk_->IsClusterMode() && m == hybridse::vm::kOffline) { + return hybridse::vm::kBatchMode; + } + return m; + } + + return hybridse::vm::EngineMode::kBatchMode; +} + +bool SQLClusterRouter::IsOnlineMode() const { + std::lock_guard<::openmldb::base::SpinMutex> lock(mu_); + auto it = session_variables_.find("execute_mode"); + if (it != session_variables_.end() && (it->second == "online" || it->second == "request")) { return true; } return false; @@ -3273,8 +3289,9 @@ ::hybridse::sdk::Status SQLClusterRouter::SetVariable(hybridse::node::SetPlanNod std::transform(value.begin(), value.end(), value.begin(), ::tolower); // TODO(hw): validation can be simpler if (key == "execute_mode") { - if (value != "online" && value != "offline") { - return {StatusCode::kCmdError, "the value of execute_mode must be online|offline"}; + auto s = sdk::internal::CheckSystemVariableSet(key, value); + if (!s.ok()) { + return {hybridse::common::kCmdError, s.ToString()}; } } else if (key == "enable_trace" || key == "sync_job") { if (value != "true" && value != "false") { diff --git a/src/sdk/sql_cluster_router.h b/src/sdk/sql_cluster_router.h index e917c170a14..e3159c0d1d7 100644 --- a/src/sdk/sql_cluster_router.h +++ b/src/sdk/sql_cluster_router.h @@ -274,7 +274,8 @@ class SQLClusterRouter : public SQLRouter { bool NotifyTableChange() override; - bool IsOnlineMode() override; + ::hybridse::vm::EngineMode GetDefaultEngineMode() const; + bool IsOnlineMode() const override; bool IsEnableTrace(); std::string GetDatabase() override; @@ -454,7 +455,7 @@ class SQLClusterRouter : public SQLRouter { DBSDK* cluster_sdk_; std::map>>> input_lru_cache_; - ::openmldb::base::SpinMutex mu_; + mutable ::openmldb::base::SpinMutex mu_; ::openmldb::base::Random rand_; std::atomic insert_memory_usage_limit_ = 0; // [0-100], the default value 0 means unlimited }; diff --git a/src/sdk/sql_router.h b/src/sdk/sql_router.h index f68d7d39a1c..55fd72b6b5e 100644 --- a/src/sdk/sql_router.h +++ b/src/sdk/sql_router.h @@ -220,7 +220,7 @@ class SQLRouter { virtual bool NotifyTableChange() = 0; - virtual bool IsOnlineMode() = 0; + virtual bool IsOnlineMode() const = 0; virtual std::string GetDatabase() = 0; diff --git a/src/sdk/sql_router_sdk.i b/src/sdk/sql_router_sdk.i index 15ea2b8e7c4..0186550e0a9 100644 --- a/src/sdk/sql_router_sdk.i +++ b/src/sdk/sql_router_sdk.i @@ -80,6 +80,7 @@ #include "sdk/sql_insert_row.h" #include "sdk/sql_delete_row.h" #include "sdk/table_reader.h" +#include "sdk/internal/system_variable.h" using hybridse::sdk::Schema; using hybridse::sdk::ColumnTypes; diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index 2f7544f2847..1545b96c9d4 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -1691,6 +1691,7 @@ void TabletImpl::ProcessQuery(bool is_sub, RpcController* ctrl, const openmldb:: auto mode = hybridse::vm::Engine::TryDetermineEngineMode(request->sql(), default_mode); ::hybridse::base::Status status; + // FIXME(someone): it does not handles batchrequest if (mode == hybridse::vm::EngineMode::kBatchMode) { // convert repeated openmldb:type::DataType into hybridse::codec::Schema hybridse::codec::Schema parameter_schema; @@ -5426,7 +5427,12 @@ void TabletImpl::RunRequestQuery(RpcController* ctrl, const openmldb::api::Query } if (ret != 0) { response.set_code(::openmldb::base::kSQLRunError); - response.set_msg("fail to run sql"); + if (ret == hybridse::common::StatusCode::kRunSessionError) { + // special handling + response.set_msg("request SQL requires a non-empty request row, but empty row received"); + } else { + response.set_msg("fail to run sql"); + } return; } else if (row.GetRowPtrCnt() != 1) { response.set_code(::openmldb::base::kSQLRunError); From cf86f04f6f64c98d6843179dcd872eeed6ddb251 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 26 Jun 2024 12:52:57 +0800 Subject: [PATCH 4/9] feat(udf): array_combine & array_join (#3945) * feat(udf): array_combine * feat(udf): new functions - array_combine - array_join * feat: casting arrays to array for array_combine WIP, string allocation need fix * fix: array_combine with non-string types * feat(array_combine): handle null inputs * fix(array_combine): behavior tweaks - use empty string if delimiter is null - restrict to array_combine(string, array ...) --- cases/query/udf_query.yaml | 92 +++++++++++++++++ hybridse/src/base/cartesian_product.cc | 62 +++++++++++ hybridse/src/base/cartesian_product.h | 32 ++++++ hybridse/src/codegen/array_ir_builder.cc | 115 +++++++++++++++++++++ hybridse/src/codegen/array_ir_builder.h | 14 ++- hybridse/src/codegen/ir_base_builder.cc | 20 +++- hybridse/src/codegen/string_ir_builder.cc | 12 +++ hybridse/src/codegen/string_ir_builder.h | 4 + hybridse/src/codegen/struct_ir_builder.cc | 96 ++++++++++++++++- hybridse/src/codegen/struct_ir_builder.h | 13 +++ hybridse/src/udf/default_defs/array_def.cc | 78 ++++++++++++++ hybridse/src/udf/udf.cc | 55 ++++++++++ hybridse/src/udf/udf.h | 3 +- hybridse/src/vm/jit_wrapper.cc | 7 ++ 14 files changed, 595 insertions(+), 8 deletions(-) create mode 100644 hybridse/src/base/cartesian_product.cc create mode 100644 hybridse/src/base/cartesian_product.h diff --git a/cases/query/udf_query.yaml b/cases/query/udf_query.yaml index ee9cad2d667..bc0cbe786fd 100644 --- a/cases/query/udf_query.yaml +++ b/cases/query/udf_query.yaml @@ -573,6 +573,98 @@ cases: - c1 bool data: | true, false + - id: array_join + mode: request-unsupport + sql: | + select + array_join(["1", "2"], ",") c1, + array_join(["1", "2"], "") c2, + array_join(["1", "2"], cast(null as string)) c3, + array_join(["1", NULL, "4", "5", NULL], "-") c4, + array_join(array[], ",") as c5 + expect: + columns: + - c1 string + - c2 string + - c3 string + - c4 string + - c5 string + rows: + - ["1,2", "12", "12", "1-4-5", ""] + - id: array_combine + mode: request-unsupport + sql: | + select + array_join(array_combine("-", ["1", "2"], ["3", "4"]), ",") c0, + expect: + columns: + - c0 string + rows: + - ["1-3,1-4,2-3,2-4"] + + - id: array_combine_2 + desc: array_combine casting array to array first + mode: request-unsupport + sql: | + select + array_join(array_combine("-", [1, 2], [3, 4]), ",") c0, + array_join(array_combine("-", [1, 2], array[3], ["5", "6"]), ",") c1, + array_join(array_combine("|", ["1"], [timestamp(1717171200000), timestamp("2024-06-02 12:00:00")]), ",") c2, + array_join(array_combine("|", ["1"]), ",") c3, + expect: + columns: + - c0 string + - c1 string + - c2 string + - c3 string + rows: + - ["1-3,1-4,2-3,2-4", "1-3-5,1-3-6,2-3-5,2-3-6", "1|2024-06-01 00:00:00,1|2024-06-02 12:00:00", "1"] + - id: array_combine_3 + desc: null values skipped + mode: request-unsupport + sql: | + select + array_join(array_combine("-", [1, NULL], [3, 4]), ",") c0, + array_join(array_combine("-", ARRAY[NULL], ["9", "8"]), ",") c1, + array_join(array_combine(string(NULL), ARRAY[1], ["9", "8"]), ",") c2, + expect: + columns: + - c0 string + - c1 string + - c2 string + rows: + - ["1-3,1-4", "", "19,18"] + - id: array_combine_4 + desc: construct array from table + mode: request-unsupport + inputs: + - name: t1 + columns: ["col1:int32", "std_ts:timestamp", "col2:string"] + indexs: ["index1:col1:std_ts"] + rows: + - [1, 1590115420001, "foo"] + - [2, 1590115420001, "bar"] + sql: | + select + col1, + array_join(array_combine("-", [col1, 10], [col2, "c2"]), ",") c0, + from t1 + expect: + columns: + - col1 int32 + - c0 string + rows: + - [1, "1-foo,1-c2,10-foo,10-c2"] + - [2, "2-bar,2-c2,10-bar,10-c2"] + - id: array_combine_err1 + mode: request-unsupport + sql: | + select + array_join(array_combine("-"), ",") c0, + expect: + success: false + msg: | + Fail to resolve expression: array_join(array_combine(-), ,) # ================================================================ # Map data type diff --git a/hybridse/src/base/cartesian_product.cc b/hybridse/src/base/cartesian_product.cc new file mode 100644 index 00000000000..f06d28edcdf --- /dev/null +++ b/hybridse/src/base/cartesian_product.cc @@ -0,0 +1,62 @@ +/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "base/cartesian_product.h" + +#include + +#include "absl/types/span.h" + +namespace hybridse { +namespace base { + +static auto cartesian_product(const std::vector>& lists) { + std::vector> result; + if (std::find_if(std::begin(lists), std::end(lists), [](auto e) -> bool { return e.size() == 0; }) != + std::end(lists)) { + return result; + } + for (auto& e : lists[0]) { + result.push_back({e}); + } + for (size_t i = 1; i < lists.size(); ++i) { + std::vector> temp; + for (auto& e : result) { + for (auto f : lists[i]) { + auto e_tmp = e; + e_tmp.push_back(f); + temp.push_back(e_tmp); + } + } + result = temp; + } + return result; +} + +std::vector> cartesian_product(absl::Span vec) { + std::vector> input; + for (auto& v : vec) { + std::vector seq(v, 0); + for (int i = 0; i < v; ++i) { + seq[i] = i; + } + input.push_back(seq); + } + return cartesian_product(input); +} + +} // namespace base +} // namespace hybridse diff --git a/hybridse/src/base/cartesian_product.h b/hybridse/src/base/cartesian_product.h new file mode 100644 index 00000000000..54d4191aba1 --- /dev/null +++ b/hybridse/src/base/cartesian_product.h @@ -0,0 +1,32 @@ +/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HYBRIDSE_SRC_BASE_CARTESIAN_PRODUCT_H_ +#define HYBRIDSE_SRC_BASE_CARTESIAN_PRODUCT_H_ + +#include + +#include "absl/types/span.h" + +namespace hybridse { +namespace base { + +std::vector> cartesian_product(absl::Span vec); + +} // namespace base +} // namespace hybridse + +#endif // HYBRIDSE_SRC_BASE_CARTESIAN_PRODUCT_H_ diff --git a/hybridse/src/codegen/array_ir_builder.cc b/hybridse/src/codegen/array_ir_builder.cc index 545ccb7555b..8e4a6005800 100644 --- a/hybridse/src/codegen/array_ir_builder.cc +++ b/hybridse/src/codegen/array_ir_builder.cc @@ -18,8 +18,12 @@ #include +#include "absl/strings/substitute.h" +#include "base/fe_status.h" +#include "codegen/cast_expr_ir_builder.h" #include "codegen/context.h" #include "codegen/ir_base_builder.h" +#include "codegen/string_ir_builder.h" namespace hybridse { namespace codegen { @@ -122,5 +126,116 @@ bool ArrayIRBuilder::CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** ou return true; } +absl::StatusOr ArrayIRBuilder::ExtractElement(CodeGenContextBase* ctx, const NativeValue& arr, + const NativeValue& key) const { + return absl::UnimplementedError("array extract element"); +} + +absl::StatusOr ArrayIRBuilder::NumElements(CodeGenContextBase* ctx, llvm::Value* arr) const { + llvm::Value* out = nullptr; + if (!Load(ctx->GetCurrentBlock(), arr, SZ_IDX, &out)) { + return absl::InternalError("codegen: fail to extract array size"); + } + + return out; +} + +absl::StatusOr ArrayIRBuilder::CastToArrayString(CodeGenContextBase* ctx, llvm::Value* src) { + auto sb = StructTypeIRBuilder::CreateStructTypeIRBuilder(ctx->GetModule(), src->getType()); + CHECK_ABSL_STATUSOR(sb); + + ArrayIRBuilder* src_builder = dynamic_cast(sb.value().get()); + if (!src_builder) { + return absl::InvalidArgumentError("input value not a array"); + } + + llvm::Type* src_ele_type = src_builder->element_type_; + if (IsStringPtr(src_ele_type)) { + // already array + return src; + } + + auto fields = src_builder->Load(ctx, src); + CHECK_ABSL_STATUSOR(fields); + llvm::Value* src_raws = fields.value().at(RAW_IDX); + llvm::Value* src_nulls = fields.value().at(NULL_IDX); + llvm::Value* num_elements = fields.value().at(SZ_IDX); + + llvm::Value* casted = nullptr; + if (!CreateDefault(ctx->GetCurrentBlock(), &casted)) { + return absl::InternalError("codegen error: fail to construct default array"); + } + // initialize each element + CHECK_ABSL_STATUS(Initialize(ctx, casted, {num_elements})); + + auto builder = ctx->GetBuilder(); + auto dst_fields = Load(ctx, casted); + CHECK_ABSL_STATUSOR(fields); + auto* raw_array_ptr = dst_fields.value().at(RAW_IDX); + auto* nullables_ptr = dst_fields.value().at(NULL_IDX); + + llvm::Type* idx_type = builder->getInt64Ty(); + llvm::Value* idx = builder->CreateAlloca(idx_type); + builder->CreateStore(builder->getInt64(0), idx); + CHECK_STATUS_TO_ABSL(ctx->CreateWhile( + [&](llvm::Value** cond) -> base::Status { + *cond = builder->CreateICmpSLT(builder->CreateLoad(idx_type, idx), num_elements); + return {}; + }, + [&]() -> base::Status { + llvm::Value* idx_val = builder->CreateLoad(idx_type, idx); + codegen::CastExprIRBuilder cast_builder(ctx->GetCurrentBlock()); + + llvm::Value* src_ele_value = + builder->CreateLoad(src_ele_type, builder->CreateGEP(src_ele_type, src_raws, idx_val)); + llvm::Value* dst_ele = + builder->CreateLoad(element_type_, builder->CreateGEP(element_type_, raw_array_ptr, idx_val)); + + codegen::StringIRBuilder str_builder(ctx->GetModule()); + auto s = str_builder.CastFrom(ctx->GetCurrentBlock(), src_ele_value, dst_ele); + CHECK_TRUE(s.ok(), common::kCodegenError, s.ToString()); + + builder->CreateStore( + builder->CreateLoad(builder->getInt1Ty(), builder->CreateGEP(builder->getInt1Ty(), src_nulls, idx_val)), + builder->CreateGEP(builder->getInt1Ty(), nullables_ptr, idx_val)); + + builder->CreateStore(builder->CreateAdd(idx_val, builder->getInt64(1)), idx); + return {}; + })); + + CHECK_ABSL_STATUS(Set(ctx, casted, {raw_array_ptr, nullables_ptr, num_elements})); + return casted; +} + +absl::Status ArrayIRBuilder::Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca, + absl::Span args) const { + auto* builder = ctx->GetBuilder(); + StringIRBuilder str_builder(ctx->GetModule()); + auto ele_type = str_builder.GetType(); + if (!alloca->getType()->isPointerTy() || alloca->getType()->getPointerElementType() != struct_type_ || + ele_type->getPointerTo() != element_type_) { + return absl::UnimplementedError(absl::Substitute( + "not able to Initialize array except array, got type $0", GetLlvmObjectString(alloca->getType()))); + } + if (args.size() != 1) { + // require one argument that is array size + return absl::InvalidArgumentError("initialize array requries one argument which is array size"); + } + if (!args[0]->getType()->isIntegerTy()) { + return absl::InvalidArgumentError("array size argument should be integer"); + } + auto sz = args[0]; + if (sz->getType() != builder->getInt64Ty()) { + CastExprIRBuilder cast_builder(ctx->GetCurrentBlock()); + base::Status s; + cast_builder.SafeCastNumber(sz, builder->getInt64Ty(), &sz, s); + CHECK_STATUS_TO_ABSL(s); + } + auto fn = ctx->GetModule()->getOrInsertFunction("hybridse_alloc_array_string", builder->getVoidTy(), + struct_type_->getPointerTo(), builder->getInt64Ty()); + + builder->CreateCall(fn, {alloca, sz}); + return absl::OkStatus(); +} } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/array_ir_builder.h b/hybridse/src/codegen/array_ir_builder.h index a4ab9dd0a6e..9ee522b509c 100644 --- a/hybridse/src/codegen/array_ir_builder.h +++ b/hybridse/src/codegen/array_ir_builder.h @@ -42,11 +42,21 @@ class ArrayIRBuilder : public StructTypeIRBuilder { CHECK_TRUE(false, common::kCodegenError, "casting to array un-implemented"); }; - private: - void InitStructType() override; + absl::StatusOr CastToArrayString(CodeGenContextBase* ctx, llvm::Value* src); + + absl::StatusOr ExtractElement(CodeGenContextBase* ctx, const NativeValue& arr, + const NativeValue& key) const override; + + absl::StatusOr NumElements(CodeGenContextBase* ctx, llvm::Value* arr) const override; bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override; + absl::Status Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca, + absl::Span args) const override; + + private: + void InitStructType() override; + private: ::llvm::Type* element_type_ = nullptr; }; diff --git a/hybridse/src/codegen/ir_base_builder.cc b/hybridse/src/codegen/ir_base_builder.cc index 72dc7c93de7..f272cf9bda0 100644 --- a/hybridse/src/codegen/ir_base_builder.cc +++ b/hybridse/src/codegen/ir_base_builder.cc @@ -575,12 +575,12 @@ bool GetFullType(node::NodeManager* nm, ::llvm::Type* type, if (type_pointee->isStructTy()) { auto* key_type = type_pointee->getStructElementType(1); const node::TypeNode* key = nullptr; - if (key_type->isPointerTy() && !GetFullType(nm, key_type->getPointerElementType(), &key)) { + if (!key_type->isPointerTy() || !GetFullType(nm, key_type->getPointerElementType(), &key)) { return false; } const node::TypeNode* value = nullptr; auto* value_type = type_pointee->getStructElementType(2); - if (value_type->isPointerTy() && !GetFullType(nm, value_type->getPointerElementType(), &value)) { + if (!value_type->isPointerTy() || !GetFullType(nm, value_type->getPointerElementType(), &value)) { return false; } @@ -590,6 +590,22 @@ bool GetFullType(node::NodeManager* nm, ::llvm::Type* type, } return false; } + case hybridse::node::kArray: { + if (type->isPointerTy()) { + auto type_pointee = type->getPointerElementType(); + if (type_pointee->isStructTy()) { + auto* key_type = type_pointee->getStructElementType(0); + const node::TypeNode* key = nullptr; + if (!key_type->isPointerTy() || !GetFullType(nm, key_type->getPointerElementType(), &key)) { + return false; + } + + *type_node = nm->MakeNode(node::DataType::kArray, key); + return true; + } + } + return false; + } default: { *type_node = nm->MakeTypeNode(base); return true; diff --git a/hybridse/src/codegen/string_ir_builder.cc b/hybridse/src/codegen/string_ir_builder.cc index 083c907fbe4..7e677f3bd3c 100644 --- a/hybridse/src/codegen/string_ir_builder.cc +++ b/hybridse/src/codegen/string_ir_builder.cc @@ -403,5 +403,17 @@ base::Status StringIRBuilder::ConcatWS(::llvm::BasicBlock* block, *output = NativeValue::CreateWithFlag(concat_str, ret_null); return base::Status(); } +absl::Status StringIRBuilder::CastFrom(llvm::BasicBlock* block, llvm::Value* src, llvm::Value* alloca) { + if (IsStringPtr(src->getType())) { + return absl::UnimplementedError("not necessary to cast string to string"); + } + ::llvm::IRBuilder<> builder(block); + ::std::string fn_name = "string." + TypeName(src->getType()); + + auto cast_func = m_->getOrInsertFunction( + fn_name, ::llvm::FunctionType::get(builder.getVoidTy(), {src->getType(), alloca->getType()}, false)); + builder.CreateCall(cast_func, {src, alloca}); + return absl::OkStatus(); +} } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/string_ir_builder.h b/hybridse/src/codegen/string_ir_builder.h index 84f73d2822d..3c6bf986a10 100644 --- a/hybridse/src/codegen/string_ir_builder.h +++ b/hybridse/src/codegen/string_ir_builder.h @@ -37,6 +37,10 @@ class StringIRBuilder : public StructTypeIRBuilder { base::Status CastFrom(::llvm::BasicBlock* block, const NativeValue& src, NativeValue* output) override; base::Status CastFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value** output); + // casting from {in} to string ptr alloca, alloca has allocated already. + // if {in} is string ptr already, it returns error status since generally it's not necessary to call this function + absl::Status CastFrom(llvm::BasicBlock* block, llvm::Value* in, llvm::Value* alloca); + bool NewString(::llvm::BasicBlock* block, ::llvm::Value** output); bool NewString(::llvm::BasicBlock* block, const std::string& str, ::llvm::Value** output); diff --git a/hybridse/src/codegen/struct_ir_builder.cc b/hybridse/src/codegen/struct_ir_builder.cc index 4b00b052c29..d288d05a57d 100644 --- a/hybridse/src/codegen/struct_ir_builder.cc +++ b/hybridse/src/codegen/struct_ir_builder.cc @@ -16,8 +16,11 @@ #include "codegen/struct_ir_builder.h" +#include + #include "absl/status/status.h" #include "absl/strings/substitute.h" +#include "codegen/array_ir_builder.h" #include "codegen/context.h" #include "codegen/date_ir_builder.h" #include "codegen/ir_base_builder.h" @@ -70,6 +73,16 @@ absl::StatusOr> StructTypeIRBuilder::Create } break; } + case node::DataType::kArray: { + assert(ctype->IsArray() && "logic error: not a array type"); + assert(ctype->GetGenericSize() == 1 && "logic error: not a array type"); + ::llvm::Type* ele_type = nullptr; + if (!codegen::GetLlvmType(m, ctype->GetGenericType(0), &ele_type)) { + return absl::InvalidArgumentError( + absl::Substitute("not able to casting array type: $0", GetLlvmObjectString(type))); + } + return std::make_unique(m, ele_type); + } default: { break; } @@ -181,6 +194,11 @@ absl::StatusOr StructTypeIRBuilder::ExtractElement(CodeGenContextBa absl::StrCat("extract element unimplemented for ", GetLlvmObjectString(struct_type_))); } +absl::StatusOr StructTypeIRBuilder::NumElements(CodeGenContextBase* ctx, llvm::Value* arr) const { + return absl::UnimplementedError( + absl::StrCat("element size unimplemented for ", GetLlvmObjectString(struct_type_))); +} + void StructTypeIRBuilder::EnsureOK() const { assert(struct_type_ != nullptr && "filed struct_type_ uninitialized"); // it's a identified type @@ -200,9 +218,9 @@ absl::Status StructTypeIRBuilder::Set(CodeGenContextBase* ctx, ::llvm::Value* st } if (struct_value->getType()->getPointerElementType() != struct_type_) { - return absl::InvalidArgumentError(absl::Substitute("input value has different type, expect $0 but got $1", - GetLlvmObjectString(struct_type_), - GetLlvmObjectString(struct_value->getType()))); + return absl::InvalidArgumentError( + absl::Substitute("input value has different type, expect $0 but got $1", GetLlvmObjectString(struct_type_), + GetLlvmObjectString(struct_value->getType()->getPointerElementType()))); } if (members.size() != struct_type_->getNumElements()) { @@ -229,6 +247,16 @@ absl::StatusOr> StructTypeIRBuilder::Load(CodeGenConte llvm::Value* struct_ptr) const { assert(ctx != nullptr && struct_ptr != nullptr); + if (!IsStructPtr(struct_ptr->getType())) { + return absl::InvalidArgumentError( + absl::StrCat("value not a struct pointer: ", GetLlvmObjectString(struct_ptr->getType()))); + } + if (struct_ptr->getType()->getPointerElementType() != struct_type_) { + return absl::InvalidArgumentError( + absl::Substitute("input value has different type, expect $0 but got $1", GetLlvmObjectString(struct_type_), + GetLlvmObjectString(struct_ptr->getType()->getPointerElementType()))); + } + std::vector res; res.reserve(struct_type_->getNumElements()); @@ -252,5 +280,67 @@ absl::StatusOr CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Ty return NativeValue(nullptr, nullptr, type); } +absl::StatusOr Combine(CodeGenContextBase* ctx, const NativeValue delimiter, + absl::Span args) { + auto builder = ctx->GetBuilder(); + + StringIRBuilder str_builder(ctx->GetModule()); + ArrayIRBuilder arr_builder(ctx->GetModule(), str_builder.GetType()->getPointerTo()); + + llvm::Value* empty_str = nullptr; + if (!str_builder.CreateDefault(ctx->GetCurrentBlock(), &empty_str)) { + return absl::InternalError("codegen error: fail to construct empty string"); + } + llvm::Value* del = builder->CreateSelect(delimiter.GetIsNull(builder), empty_str, delimiter.GetValue(builder)); + + llvm::Type* input_arr_type = arr_builder.GetType()->getPointerTo(); + llvm::Value* empty_arr = nullptr; + if (!arr_builder.CreateDefault(ctx->GetCurrentBlock(), &empty_arr)) { + return absl::InternalError("codegen error: fail to construct empty string of array"); + } + llvm::Value* input_arrays = builder->CreateAlloca(input_arr_type, builder->getInt32(args.size()), "array_data"); + node::NodeManager nm; + std::vector casted_args(args.size()); + for (int i = 0; i < args.size(); ++i) { + const node::TypeNode* tp = nullptr; + if (!GetFullType(&nm, args.at(i).GetType(), &tp)) { + return absl::InternalError("codegen error: fail to get valid type from llvm value"); + } + if (!tp->IsArray() || tp->GetGenericSize() != 1) { + return absl::InternalError("codegen error: arguments to array_combine is not ARRAY"); + } + if (!tp->GetGenericType(0)->IsString()) { + auto s = arr_builder.CastToArrayString(ctx, args.at(i).GetRaw()); + CHECK_ABSL_STATUSOR(s); + casted_args.at(i) = NativeValue::Create(s.value()); + } else { + casted_args.at(i) = args.at(i); + } + + auto safe_str_arr = + builder->CreateSelect(casted_args.at(i).GetIsNull(builder), empty_arr, casted_args.at(i).GetRaw()); + builder->CreateStore(safe_str_arr, builder->CreateGEP(input_arr_type, input_arrays, builder->getInt32(i))); + } + + ::llvm::FunctionCallee array_combine_fn = ctx->GetModule()->getOrInsertFunction( + "hybridse_array_combine", builder->getVoidTy(), str_builder.GetType()->getPointerTo(), builder->getInt32Ty(), + input_arr_type->getPointerTo(), input_arr_type); + assert(array_combine_fn); + + llvm::Value* out = builder->CreateAlloca(arr_builder.GetType()); + builder->CreateCall(array_combine_fn, { + del, // delimiter should ensure non-null + builder->getInt32(args.size()), // num of arrays + input_arrays, // ArrayRef** + out // output string + }); + + return NativeValue::Create(out); +} + +absl::Status StructTypeIRBuilder::Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca, + absl::Span args) const { + return absl::UnimplementedError(absl::StrCat("Initialize for type ", GetLlvmObjectString(struct_type_))); +} } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/struct_ir_builder.h b/hybridse/src/codegen/struct_ir_builder.h index 8529c2d7848..4c7dba732dc 100644 --- a/hybridse/src/codegen/struct_ir_builder.h +++ b/hybridse/src/codegen/struct_ir_builder.h @@ -58,6 +58,9 @@ class StructTypeIRBuilder : public TypeIRBuilder { virtual absl::StatusOr<::llvm::Value*> ConstructFromRaw(CodeGenContextBase* ctx, absl::Span<::llvm::Value* const> args) const; + virtual absl::Status Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca, + absl::Span args) const; + // Extract element value from composite data type // 1. extract from array type by index // 2. extract from struct type by field name @@ -65,6 +68,12 @@ class StructTypeIRBuilder : public TypeIRBuilder { virtual absl::StatusOr ExtractElement(CodeGenContextBase* ctx, const NativeValue& arr, const NativeValue& key) const; + // Get size of the elements inside value {arr} + // - if {arr} is array/map, return size of array/map + // - if {arr} is struct, return number of struct fields + // - otherwise report error + virtual absl::StatusOr NumElements(CodeGenContextBase* ctx, llvm::Value* arr) const; + ::llvm::Type* GetType() const; std::string GetTypeDebugString() const; @@ -100,6 +109,10 @@ class StructTypeIRBuilder : public TypeIRBuilder { // returns NativeValue{raw, is_null=true} on success, raw is ensured to be not nullptr absl::StatusOr CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Type* type); +// Do the cartesian product for a list of arrry +// output a array of string, each value is a pair (A1, B2, C3...), as "A1-B2-C3-...", "-" is the delimiter +absl::StatusOr Combine(CodeGenContextBase* ctx, const NativeValue delimiter, + absl::Span args); } // namespace codegen } // namespace hybridse #endif // HYBRIDSE_SRC_CODEGEN_STRUCT_IR_BUILDER_H_ diff --git a/hybridse/src/udf/default_defs/array_def.cc b/hybridse/src/udf/default_defs/array_def.cc index b5c36bc3d7e..5f1bebfaaf6 100644 --- a/hybridse/src/udf/default_defs/array_def.cc +++ b/hybridse/src/udf/default_defs/array_def.cc @@ -15,6 +15,7 @@ */ #include "absl/strings/str_split.h" +#include "codegen/struct_ir_builder.h" #include "udf/default_udf_library.h" #include "udf/udf.h" #include "udf/udf_registry.h" @@ -101,6 +102,36 @@ void SplitString(StringRef* str, StringRef* delimeter, ArrayRef* arra } } +void array_join(ArrayRef* arr, StringRef* del, bool del_null, StringRef* out) { + int sz = 0; + for (int i = 0; i < arr->size; ++i) { + if (!arr->nullables[i]) { + if (!del_null && i > 0) { + sz += del->size_; + } + sz += arr->raw[i]->size_; + } + } + + auto buf = udf::v1::AllocManagedStringBuf(sz); + memset(buf, 0, sz); + + int32_t idx = 0; + for (int i = 0; i < arr->size; ++i) { + if (!arr->nullables[i]) { + if (!del_null && i > 0) { + memcpy(buf + idx, del->data_, del->size_); + idx += del->size_; + } + memcpy(buf + idx, arr->raw[i]->data_, arr->raw[i]->size_); + idx += arr->raw[i]->size_; + } + } + + out->data_ = buf; + out->size_ = sz; +} + // =========================================================== // // UDF Register Entry // =========================================================== // @@ -148,6 +179,53 @@ void DefaultUdfLibrary::InitArrayUdfs() { @endcode @since 0.7.0)"); + RegisterExternal("array_join") + .args, Nullable>(array_join) + .doc(R"( + @brief array_join(array, delimiter) - Concatenates the elements of the given array using the delimiter. Any null value is filtered. + + Example: + + @code{.sql} + select array_join(["1", "2"], "-"); + -- output "1-2" + @endcode + @since 0.9.2)"); + + RegisterCodeGenUdf("array_combine") + .variadic_args>( + [](UdfResolveContext* ctx, const ExprAttrNode& delimit, const std::vector& arg_attrs, + ExprAttrNode* out) -> base::Status { + CHECK_TRUE(!arg_attrs.empty(), common::kCodegenError, "at least one array required by array_combine"); + for (auto & val : arg_attrs) { + CHECK_TRUE(val.type()->IsArray(), common::kCodegenError, "argument to array_combine must be array"); + } + auto nm = ctx->node_manager(); + out->SetType(nm->MakeNode(node::kArray, nm->MakeNode(node::kVarchar))); + out->SetNullable(false); + return {}; + }, + [](codegen::CodeGenContext* ctx, codegen::NativeValue del, const std::vector& args, + const node::ExprAttrNode& return_info, codegen::NativeValue* out) -> base::Status { + auto os = codegen::Combine(ctx, del, args); + CHECK_TRUE(os.ok(), common::kCodegenError, os.status().ToString()); + *out = os.value(); + return {}; + }) + .doc(R"( + @brief array_combine(delimiter, array1, array2, ...) + + return array of strings for input array1, array2, ... doing cartesian product. Each product is joined with + {delimiter} as a string. Empty string used if {delimiter} is null. + + Example: + + @code{.sql} + select array_combine("-", ["1", "2"], ["3", "4"]); + -- output ["1-3", "1-4", "2-3", "2-4"] + @endcode + @since 0.9.2 + )"); } } // namespace udf } // namespace hybridse diff --git a/hybridse/src/udf/udf.cc b/hybridse/src/udf/udf.cc index b32d75d4ac8..e5faf25db1e 100644 --- a/hybridse/src/udf/udf.cc +++ b/hybridse/src/udf/udf.cc @@ -21,11 +21,13 @@ #include #include +#include #include "absl/strings/ascii.h" #include "absl/strings/str_replace.h" #include "absl/time/civil_time.h" #include "absl/time/time.h" +#include "base/cartesian_product.h" #include "base/iterator.h" #include "boost/date_time/gregorian/conversion.hpp" #include "boost/date_time/gregorian/parsers.hpp" @@ -1413,6 +1415,59 @@ void printLog(const char* fmt) { } } +// each variadic arg is ArrayRef* +void array_combine(codec::StringRef *del, int32_t cnt, ArrayRef **data, + ArrayRef *out) { + std::vector arr_szs(cnt, 0); + for (int32_t i = 0; i < cnt; ++i) { + auto arr = data[i]; + arr_szs.at(i) = arr->size; + } + + // cal cartesian products + auto products = hybridse::base::cartesian_product(arr_szs); + + auto real_sz = products.size(); + v1::AllocManagedArray(out, products.size()); + + for (int prod_idx = 0; prod_idx < products.size(); ++prod_idx) { + auto &prod = products.at(prod_idx); + int32_t sz = 0; + for (int i = 0; i < prod.size(); ++i) { + if (!data[i]->nullables[prod.at(i)]) { + // delimiter would be empty string if null + if (i > 0) { + sz += del->size_; + } + sz += data[i]->raw[prod.at(i)]->size_; + } else { + // null exists in current product + // the only option now is to skip + real_sz--; + continue; + } + } + auto buf = v1::AllocManagedStringBuf(sz); + int32_t idx = 0; + for (int i = 0; i < prod.size(); ++i) { + if (!data[i]->nullables[prod.at(i)]) { + if (i > 0 && del->size_ > 0) { + memcpy(buf + idx, del->data_, del->size_); + idx += del->size_; + } + memcpy(buf + idx, data[i]->raw[prod.at(i)]->data_, data[i]->raw[prod.at(i)]->size_); + idx += data[i]->raw[prod.at(i)]->size_; + } + } + + out->nullables[prod_idx] = false; + out->raw[prod_idx]->data_ = buf; + out->raw[prod_idx]->size_ = sz; + } + + out->size = real_sz; +} + } // namespace v1 bool RegisterMethod(UdfLibrary *lib, const std::string &fn_name, hybridse::node::TypeNode *ret, diff --git a/hybridse/src/udf/udf.h b/hybridse/src/udf/udf.h index b7f222433a7..480e4f89f3c 100644 --- a/hybridse/src/udf/udf.h +++ b/hybridse/src/udf/udf.h @@ -520,7 +520,8 @@ void hex(StringRef *str, StringRef *output); void unhex(StringRef *str, StringRef *output, bool* is_null); void printLog(const char* fmt); - +void array_combine(codec::StringRef *del, int32_t cnt, ArrayRef **data, + ArrayRef *out); } // namespace v1 /// \brief register native udf related methods into given UdfLibrary `lib` diff --git a/hybridse/src/vm/jit_wrapper.cc b/hybridse/src/vm/jit_wrapper.cc index d4b4df1b01d..220a7d93085 100644 --- a/hybridse/src/vm/jit_wrapper.cc +++ b/hybridse/src/vm/jit_wrapper.cc @@ -17,6 +17,8 @@ #include #include + +#include "base/cartesian_product.h" #include "glog/logging.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" @@ -251,6 +253,11 @@ void InitBuiltinJitSymbols(HybridSeJitWrapper* jit) { "fmod", reinterpret_cast( static_cast(&fmod))); jit->AddExternalFunction("fmodf", reinterpret_cast(&fmodf)); + + // cartesian product + jit->AddExternalFunction("hybridse_array_combine", reinterpret_cast(&hybridse::udf::v1::array_combine)); + jit->AddExternalFunction("hybridse_alloc_array_string", + reinterpret_cast(&hybridse::udf::v1::AllocManagedArray)); } } // namespace vm From e3da2a6bbfa45792ae925e27667dd38af8cd3b99 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 26 Jun 2024 16:05:58 +0800 Subject: [PATCH 5/9] feat: support batchrequest in ProcessQuery (#3938) --- src/tablet/tablet_impl.cc | 239 ++++++++++++++++++++++++-------------- 1 file changed, 151 insertions(+), 88 deletions(-) diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index 1545b96c9d4..230b5c46a09 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -20,6 +20,7 @@ #include #include #include +#include "vm/sql_compiler.h" #ifdef DISALLOW_COPY_AND_ASSIGN #undef DISALLOW_COPY_AND_ASSIGN #endif @@ -1691,98 +1692,127 @@ void TabletImpl::ProcessQuery(bool is_sub, RpcController* ctrl, const openmldb:: auto mode = hybridse::vm::Engine::TryDetermineEngineMode(request->sql(), default_mode); ::hybridse::base::Status status; - // FIXME(someone): it does not handles batchrequest - if (mode == hybridse::vm::EngineMode::kBatchMode) { - // convert repeated openmldb:type::DataType into hybridse::codec::Schema - hybridse::codec::Schema parameter_schema; - for (int i = 0; i < request->parameter_types().size(); i++) { - auto column = parameter_schema.Add(); - hybridse::type::Type hybridse_type; - - if (!openmldb::schema::SchemaAdapter::ConvertType(request->parameter_types(i), &hybridse_type)) { - response->set_msg("Invalid parameter type: " + - openmldb::type::DataType_Name(request->parameter_types(i))); - response->set_code(::openmldb::base::kSQLCompileError); - return; + switch (mode) { + case hybridse::vm::EngineMode::kBatchMode: { + // convert repeated openmldb:type::DataType into hybridse::codec::Schema + hybridse::codec::Schema parameter_schema; + for (int i = 0; i < request->parameter_types().size(); i++) { + auto column = parameter_schema.Add(); + hybridse::type::Type hybridse_type; + + if (!openmldb::schema::SchemaAdapter::ConvertType(request->parameter_types(i), &hybridse_type)) { + response->set_msg("Invalid parameter type: " + + openmldb::type::DataType_Name(request->parameter_types(i))); + response->set_code(::openmldb::base::kSQLCompileError); + return; + } + column->set_type(hybridse_type); } - column->set_type(hybridse_type); - } - ::hybridse::vm::BatchRunSession session; - if (request->is_debug()) { - session.EnableDebug(); - } - session.SetParameterSchema(parameter_schema); - { - bool ok = engine_->Get(request->sql(), request->db(), session, status); - if (!ok) { - response->set_msg(status.msg); - response->set_code(::openmldb::base::kSQLCompileError); - DLOG(WARNING) << "fail to compile sql " << request->sql() << ", message: " << status.msg; - return; + ::hybridse::vm::BatchRunSession session; + if (request->is_debug()) { + session.EnableDebug(); + } + session.SetParameterSchema(parameter_schema); + { + bool ok = engine_->Get(request->sql(), request->db(), session, status); + if (!ok) { + response->set_msg(status.msg); + response->set_code(::openmldb::base::kSQLCompileError); + DLOG(WARNING) << "fail to compile sql " << request->sql() << ", message: " << status.msg; + return; + } } - } - ::hybridse::codec::Row parameter_row; - auto& request_buf = static_cast(ctrl)->request_attachment(); - if (request->parameter_row_size() > 0 && - !codec::DecodeRpcRow(request_buf, 0, request->parameter_row_size(), request->parameter_row_slices(), - ¶meter_row)) { - response->set_code(::openmldb::base::kSQLRunError); - response->set_msg("fail to decode parameter row"); - return; - } - std::vector<::hybridse::codec::Row> output_rows; - int32_t run_ret = session.Run(parameter_row, output_rows); - if (run_ret != 0) { - response->set_msg(status.msg); - response->set_code(::openmldb::base::kSQLRunError); - DLOG(WARNING) << "fail to run sql: " << request->sql(); - return; - } - uint32_t byte_size = 0; - uint32_t count = 0; - for (auto& output_row : output_rows) { - if (FLAGS_scan_max_bytes_size > 0 && byte_size > FLAGS_scan_max_bytes_size) { - LOG(WARNING) << "reach the max byte size " << FLAGS_scan_max_bytes_size << " truncate result"; - response->set_schema(session.GetEncodedSchema()); - response->set_byte_size(byte_size); - response->set_count(count); - response->set_code(::openmldb::base::kOk); + ::hybridse::codec::Row parameter_row; + auto& request_buf = static_cast(ctrl)->request_attachment(); + if (request->parameter_row_size() > 0 && + !codec::DecodeRpcRow(request_buf, 0, request->parameter_row_size(), request->parameter_row_slices(), + ¶meter_row)) { + response->set_code(::openmldb::base::kSQLRunError); + response->set_msg("fail to decode parameter row"); + return; + } + std::vector<::hybridse::codec::Row> output_rows; + int32_t run_ret = session.Run(parameter_row, output_rows); + if (run_ret != 0) { + response->set_msg(status.msg); + response->set_code(::openmldb::base::kSQLRunError); + DLOG(WARNING) << "fail to run sql: " << request->sql(); return; } - byte_size += output_row.size(); - buf->append(reinterpret_cast(output_row.buf()), output_row.size()); - count += 1; + uint32_t byte_size = 0; + uint32_t count = 0; + for (auto& output_row : output_rows) { + if (FLAGS_scan_max_bytes_size > 0 && byte_size > FLAGS_scan_max_bytes_size) { + LOG(WARNING) << "reach the max byte size " << FLAGS_scan_max_bytes_size << " truncate result"; + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + return; + } + byte_size += output_row.size(); + buf->append(reinterpret_cast(output_row.buf()), output_row.size()); + count += 1; + } + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + DLOG(INFO) << "handle batch sql " << request->sql() << " with record cnt " << count << " byte size " + << byte_size; + break; } - response->set_schema(session.GetEncodedSchema()); - response->set_byte_size(byte_size); - response->set_count(count); - response->set_code(::openmldb::base::kOk); - DLOG(INFO) << "handle batch sql " << request->sql() << " with record cnt " << count << " byte size " - << byte_size; - } else { - ::hybridse::vm::RequestRunSession session; - if (request->is_debug()) { - session.EnableDebug(); - } - if (request->is_procedure()) { - const std::string& db_name = request->db(); - const std::string& sp_name = request->sp_name(); - std::shared_ptr request_compile_info; - { - hybridse::base::Status status; - request_compile_info = sp_cache_->GetRequestInfo(db_name, sp_name, status); - if (!status.isOK()) { - response->set_code(::openmldb::base::ReturnCode::kProcedureNotFound); + case hybridse::vm::kRequestMode: { + ::hybridse::vm::RequestRunSession session; + if (request->is_debug()) { + session.EnableDebug(); + } + if (request->is_procedure()) { + const std::string& db_name = request->db(); + const std::string& sp_name = request->sp_name(); + std::shared_ptr request_compile_info; + { + hybridse::base::Status status; + request_compile_info = sp_cache_->GetRequestInfo(db_name, sp_name, status); + if (!status.isOK()) { + response->set_code(::openmldb::base::ReturnCode::kProcedureNotFound); + response->set_msg(status.msg); + PDLOG(WARNING, status.msg.c_str()); + return; + } + } + session.SetCompileInfo(request_compile_info); + session.SetSpName(sp_name); + RunRequestQuery(ctrl, *request, session, *response, *buf); + } else { + bool ok = engine_->Get(request->sql(), request->db(), session, status); + if (!ok || session.GetCompileInfo() == nullptr) { response->set_msg(status.msg); - PDLOG(WARNING, status.msg.c_str()); + response->set_code(::openmldb::base::kSQLCompileError); + DLOG(WARNING) << "fail to compile sql in request mode:\n" << request->sql(); return; } + RunRequestQuery(ctrl, *request, session, *response, *buf); + } + const std::string& sql = session.GetCompileInfo()->GetSql(); + if (response->code() != ::openmldb::base::kOk) { + DLOG(WARNING) << "fail to run sql " << sql << " error msg: " << response->msg(); + } else { + DLOG(INFO) << "handle request sql " << sql; + } + break; + } + case hybridse::vm::kBatchRequestMode: { + // we support a simplified batch request query here + // not procedure + // no parameter input or bachrequst row + // batchrequest row must specified in CONFIG (values = ...) + ::hybridse::base::Status status; + ::hybridse::vm::BatchRequestRunSession session; + if (request->is_debug()) { + session.EnableDebug(); } - session.SetCompileInfo(request_compile_info); - session.SetSpName(sp_name); - RunRequestQuery(ctrl, *request, session, *response, *buf); - } else { bool ok = engine_->Get(request->sql(), request->db(), session, status); if (!ok || session.GetCompileInfo() == nullptr) { response->set_msg(status.msg); @@ -1790,13 +1820,46 @@ void TabletImpl::ProcessQuery(bool is_sub, RpcController* ctrl, const openmldb:: DLOG(WARNING) << "fail to compile sql in request mode:\n" << request->sql(); return; } - RunRequestQuery(ctrl, *request, session, *response, *buf); + auto info = std::dynamic_pointer_cast(session.GetCompileInfo()); + if (info && info->get_sql_context().request_rows.empty()) { + response->set_msg("batch request values must specified in SQL CONFIG (values = [...])"); + response->set_code(::openmldb::base::kSQLCompileError); + return; + } + std::vector<::hybridse::codec::Row> output_rows; + std::vector<::hybridse::codec::Row> empty_inputs; + int32_t run_ret = session.Run(empty_inputs, output_rows); + if (run_ret != 0) { + response->set_msg(status.msg); + response->set_code(::openmldb::base::kSQLRunError); + DLOG(WARNING) << "fail to run batchrequest sql: " << request->sql(); + return; + } + uint32_t byte_size = 0; + uint32_t count = 0; + for (auto& output_row : output_rows) { + if (FLAGS_scan_max_bytes_size > 0 && byte_size > FLAGS_scan_max_bytes_size) { + LOG(WARNING) << "reach the max byte size " << FLAGS_scan_max_bytes_size << " truncate result"; + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + return; + } + byte_size += output_row.size(); + buf->append(reinterpret_cast(output_row.buf()), output_row.size()); + count += 1; + } + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + break; } - const std::string& sql = session.GetCompileInfo()->GetSql(); - if (response->code() != ::openmldb::base::kOk) { - DLOG(WARNING) << "fail to run sql " << sql << " error msg: " << response->msg(); - } else { - DLOG(INFO) << "handle request sql " << sql; + default: { + response->set_msg("un-implemented execute_mode: " + hybridse::vm::EngineModeName(mode)); + response->set_code(::openmldb::base::kSQLCompileError); + break; } } } From 2a739528e2af6c2702590be219c25beefac10f9c Mon Sep 17 00:00:00 2001 From: oh2024 <162292688+oh2024@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:50:58 +0800 Subject: [PATCH 6/9] feat: user authz (#3941) * feat: change user table to match mysql * feat: support user authz * fix: cean up created users --- hybridse/include/node/node_enum.h | 4 + hybridse/include/node/plan_node.h | 61 ++++++++++++ hybridse/include/node/sql_node.h | 58 +++++++++++ hybridse/src/plan/planner.cc | 16 +++ hybridse/src/planv2/ast_node_converter.cc | 90 +++++++++++++++++ hybridse/src/planv2/ast_node_converter.h | 6 ++ src/auth/user_access_manager.cc | 41 ++++++-- src/auth/user_access_manager.h | 10 +- src/base/status.h | 3 +- src/client/ns_client.cc | 29 ++++++ src/client/ns_client.h | 5 + src/cmd/sql_cmd_test.cc | 116 +++++++++++++++++++++- src/nameserver/name_server_impl.cc | 72 +++++++++++--- src/nameserver/name_server_impl.h | 5 +- src/nameserver/system_table.h | 18 ++-- src/proto/name_server.proto | 19 ++++ src/sdk/sql_cluster_router.cc | 24 +++++ 17 files changed, 543 insertions(+), 34 deletions(-) diff --git a/hybridse/include/node/node_enum.h b/hybridse/include/node/node_enum.h index 7b189aa6aac..eea2bd9a953 100644 --- a/hybridse/include/node/node_enum.h +++ b/hybridse/include/node/node_enum.h @@ -98,6 +98,8 @@ enum SqlNodeType { kColumnSchema, kCreateUserStmt, kAlterUserStmt, + kGrantStmt, + kRevokeStmt, kCallStmt, kSqlNodeTypeLast, // debug type kVariadicUdfDef, @@ -347,6 +349,8 @@ enum PlanType { kPlanTypeShow, kPlanTypeCreateUser, kPlanTypeAlterUser, + kPlanTypeGrant, + kPlanTypeRevoke, kPlanTypeCallStmt, kUnknowPlan = -1, }; diff --git a/hybridse/include/node/plan_node.h b/hybridse/include/node/plan_node.h index ec82b6a586f..0e5683c4702 100644 --- a/hybridse/include/node/plan_node.h +++ b/hybridse/include/node/plan_node.h @@ -739,6 +739,67 @@ class CreateUserPlanNode : public LeafPlanNode { const std::shared_ptr options_; }; +class GrantPlanNode : public LeafPlanNode { + public: + explicit GrantPlanNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, + std::vector grantees, bool with_grant_option) + : LeafPlanNode(kPlanTypeGrant), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees), + with_grant_option_(with_grant_option) {} + ~GrantPlanNode() = default; + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + const bool WithGrantOption() const { return with_grant_option_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; + bool with_grant_option_; +}; + +class RevokePlanNode : public LeafPlanNode { + public: + explicit RevokePlanNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, + std::vector grantees) + : LeafPlanNode(kPlanTypeRevoke), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees) {} + ~RevokePlanNode() = default; + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; +}; + class AlterUserPlanNode : public LeafPlanNode { public: explicit AlterUserPlanNode(const std::string& name, bool if_exists, std::shared_ptr options) diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index 96ea7a94163..52542426c2a 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -2421,6 +2421,64 @@ class AlterUserNode : public SqlNode { const std::shared_ptr options_; }; +class GrantNode : public SqlNode { + public: + explicit GrantNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, std::vector grantees, + bool with_grant_option) + : SqlNode(kGrantStmt, 0, 0), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees), + with_grant_option_(with_grant_option) {} + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + const bool WithGrantOption() const { return with_grant_option_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; + bool with_grant_option_; +}; + +class RevokeNode : public SqlNode { + public: + explicit RevokeNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, std::vector grantees) + : SqlNode(kRevokeStmt, 0, 0), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees) {} + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; +}; + class ExplainNode : public SqlNode { public: explicit ExplainNode(const QueryNode *query, node::ExplainType explain_type) diff --git a/hybridse/src/plan/planner.cc b/hybridse/src/plan/planner.cc index b2a57b4128c..3a3984c9b16 100644 --- a/hybridse/src/plan/planner.cc +++ b/hybridse/src/plan/planner.cc @@ -768,6 +768,22 @@ base::Status SimplePlanner::CreatePlanTree(const NodePointVector &parser_trees, plan_trees.push_back(create_user_plan_node); break; } + case ::hybridse::node::kGrantStmt: { + auto node = dynamic_cast(parser_tree); + auto grant_plan_node = node_manager_->MakeNode( + node->TargetType(), node->Database(), node->Target(), node->Privileges(), node->IsAllPrivileges(), + node->Grantees(), node->WithGrantOption()); + plan_trees.push_back(grant_plan_node); + break; + } + case ::hybridse::node::kRevokeStmt: { + auto node = dynamic_cast(parser_tree); + auto revoke_plan_node = node_manager_->MakeNode( + node->TargetType(), node->Database(), node->Target(), node->Privileges(), node->IsAllPrivileges(), + node->Grantees()); + plan_trees.push_back(revoke_plan_node); + break; + } case ::hybridse::node::kAlterUserStmt: { auto node = dynamic_cast(parser_tree); auto alter_user_plan_node = node_manager_->MakeNode(node->Name(), diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc index a8453e1221c..23e56924ae2 100644 --- a/hybridse/src/planv2/ast_node_converter.cc +++ b/hybridse/src/planv2/ast_node_converter.cc @@ -24,6 +24,7 @@ #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/types/span.h" +#include "ast_node_converter.h" #include "base/fe_status.h" #include "node/sql_node.h" #include "udf/udf.h" @@ -725,6 +726,20 @@ base::Status ConvertStatement(const zetasql::ASTStatement* statement, node::Node *output = create_user_node; break; } + case zetasql::AST_GRANT_STATEMENT: { + const zetasql::ASTGrantStatement* grant_stmt = statement->GetAsOrNull(); + node::GrantNode* grant_node = nullptr; + CHECK_STATUS(ConvertGrantStatement(grant_stmt, node_manager, &grant_node)) + *output = grant_node; + break; + } + case zetasql::AST_REVOKE_STATEMENT: { + const zetasql::ASTRevokeStatement* revoke_stmt = statement->GetAsOrNull(); + node::RevokeNode* revoke_node = nullptr; + CHECK_STATUS(ConvertRevokeStatement(revoke_stmt, node_manager, &revoke_node)) + *output = revoke_node; + break; + } case zetasql::AST_ALTER_USER_STATEMENT: { const zetasql::ASTAlterUserStatement* alter_user_stmt = statement->GetAsOrNull(); @@ -2133,6 +2148,81 @@ base::Status ConvertAlterUserStatement(const zetasql::ASTAlterUserStatement* roo return base::Status::OK(); } +base::Status ConvertGrantStatement(const zetasql::ASTGrantStatement* root, node::NodeManager* node_manager, + node::GrantNode** output) { + CHECK_TRUE(root != nullptr, common::kSqlAstError, "not an ASTGrantStatement"); + std::vector target_path; + CHECK_STATUS(AstPathExpressionToStringList(root->target_path(), target_path)); + std::optional target_type = std::nullopt; + if (root->target_type() != nullptr) { + target_type = root->target_type()->GetAsString(); + } + + std::vector privileges; + std::vector grantees; + for (auto privilege : root->privileges()->privileges()) { + if (privilege == nullptr) { + continue; + } + + auto privilege_action = privilege->privilege_action(); + if (privilege_action != nullptr) { + privileges.push_back(privilege_action->GetAsString()); + } + } + + for (auto grantee : root->grantee_list()->grantee_list()) { + if (grantee == nullptr) { + continue; + } + + std::string grantee_str; + CHECK_STATUS(AstStringLiteralToString(grantee, &grantee_str)); + grantees.push_back(grantee_str); + } + *output = node_manager->MakeNode(target_type, target_path.at(0), target_path.at(1), privileges, + root->privileges()->is_all_privileges(), grantees, + root->with_grant_option()); + return base::Status::OK(); +} + +base::Status ConvertRevokeStatement(const zetasql::ASTRevokeStatement* root, node::NodeManager* node_manager, + node::RevokeNode** output) { + CHECK_TRUE(root != nullptr, common::kSqlAstError, "not an ASTRevokeStatement"); + std::vector target_path; + CHECK_STATUS(AstPathExpressionToStringList(root->target_path(), target_path)); + std::optional target_type = std::nullopt; + if (root->target_type() != nullptr) { + target_type = root->target_type()->GetAsString(); + } + + std::vector privileges; + std::vector grantees; + for (auto privilege : root->privileges()->privileges()) { + if (privilege == nullptr) { + continue; + } + + auto privilege_action = privilege->privilege_action(); + if (privilege_action != nullptr) { + privileges.push_back(privilege_action->GetAsString()); + } + } + + for (auto grantee : root->grantee_list()->grantee_list()) { + if (grantee == nullptr) { + continue; + } + + std::string grantee_str; + CHECK_STATUS(AstStringLiteralToString(grantee, &grantee_str)); + grantees.push_back(grantee_str); + } + *output = node_manager->MakeNode(target_type, target_path.at(0), target_path.at(1), privileges, + root->privileges()->is_all_privileges(), grantees); + return base::Status::OK(); +} + base::Status ConvertCreateIndexStatement(const zetasql::ASTCreateIndexStatement* root, node::NodeManager* node_manager, node::CreateIndexNode** output) { CHECK_TRUE(nullptr != root, common::kSqlAstError, "not an ASTCreateIndexStatement") diff --git a/hybridse/src/planv2/ast_node_converter.h b/hybridse/src/planv2/ast_node_converter.h index 631569156d2..edc0fb60c50 100644 --- a/hybridse/src/planv2/ast_node_converter.h +++ b/hybridse/src/planv2/ast_node_converter.h @@ -72,6 +72,12 @@ base::Status ConvertCreateUserStatement(const zetasql::ASTCreateUserStatement* r base::Status ConvertAlterUserStatement(const zetasql::ASTAlterUserStatement* root, node::NodeManager* node_manager, node::AlterUserNode** output); +base::Status ConvertGrantStatement(const zetasql::ASTGrantStatement* root, node::NodeManager* node_manager, + node::GrantNode** output); + +base::Status ConvertRevokeStatement(const zetasql::ASTRevokeStatement* root, node::NodeManager* node_manager, + node::RevokeNode** output); + base::Status ConvertQueryNode(const zetasql::ASTQuery* root, node::NodeManager* node_manager, node::QueryNode** output); base::Status ConvertQueryExpr(const zetasql::ASTQueryExpression* query_expr, node::NodeManager* node_manager, diff --git a/src/auth/user_access_manager.cc b/src/auth/user_access_manager.cc index d668a7dc497..1f354998ef3 100644 --- a/src/auth/user_access_manager.cc +++ b/src/auth/user_access_manager.cc @@ -47,7 +47,7 @@ void UserAccessManager::StopSyncTask() { void UserAccessManager::SyncWithDB() { if (auto it_pair = user_table_iterator_factory_(::openmldb::nameserver::USER_INFO_NAME); it_pair) { - auto new_user_map = std::make_unique>(); + auto new_user_map = std::make_unique>(); auto it = it_pair->first.get(); it->SeekToFirst(); while (it->Valid()) { @@ -56,13 +56,18 @@ void UserAccessManager::SyncWithDB() { auto size = it->GetValue().size(); codec::RowView row_view(*it_pair->second.get(), buf, size); std::string host, user, password; + std::string privilege_level_str; row_view.GetStrValue(0, &host); row_view.GetStrValue(1, &user); row_view.GetStrValue(2, &password); + row_view.GetStrValue(5, &privilege_level_str); + openmldb::nameserver::PrivilegeLevel privilege_level; + ::openmldb::nameserver::PrivilegeLevel_Parse(privilege_level_str, &privilege_level); + UserRecord user_record = {password, privilege_level}; if (host == "%") { - new_user_map->emplace(user, password); + new_user_map->emplace(user, user_record); } else { - new_user_map->emplace(FormUserHost(user, host), password); + new_user_map->emplace(FormUserHost(user, host), user_record); } it->Next(); } @@ -70,12 +75,36 @@ void UserAccessManager::SyncWithDB() { } } +std::optional UserAccessManager::GetUserPassword(const std::string& host, const std::string& user) { + if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) { + return user_record.value().password; + } else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) { + return stored_password.value().password; + } else { + return std::nullopt; + } +} + bool UserAccessManager::IsAuthenticated(const std::string& host, const std::string& user, const std::string& password) { - if (auto stored_password = user_map_.Get(FormUserHost(user, host)); stored_password.has_value()) { - return stored_password.value() == password; + if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) { + return user_record.value().password == password; } else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) { - return stored_password.value() == password; + return stored_password.value().password == password; } return false; } + +::openmldb::nameserver::PrivilegeLevel UserAccessManager::GetPrivilegeLevel(const std::string& user_at_host) { + std::size_t at_pos = user_at_host.find('@'); + if (at_pos != std::string::npos) { + std::string user = user_at_host.substr(0, at_pos); + std::string host = user_at_host.substr(at_pos + 1); + if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) { + return user_record.value().privilege_level; + } else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) { + return stored_password.value().privilege_level; + } + } + return ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE; +} } // namespace openmldb::auth diff --git a/src/auth/user_access_manager.h b/src/auth/user_access_manager.h index 996efc326c4..9de6890f93a 100644 --- a/src/auth/user_access_manager.h +++ b/src/auth/user_access_manager.h @@ -26,9 +26,15 @@ #include #include "catalog/distribute_iterator.h" +#include "proto/name_server.pb.h" #include "refreshable_map.h" namespace openmldb::auth { +struct UserRecord { + std::string password; + ::openmldb::nameserver::PrivilegeLevel privilege_level; +}; + class UserAccessManager { public: using IteratorFactory = std::function GetUserPassword(const std::string& host, const std::string& user); private: IteratorFactory user_table_iterator_factory_; - RefreshableMap user_map_; + RefreshableMap user_map_; std::thread sync_task_thread_; std::promise stop_promise_; void StartSyncTask(); diff --git a/src/base/status.h b/src/base/status.h index c7e5ec75198..a2da254e78e 100644 --- a/src/base/status.h +++ b/src/base/status.h @@ -186,7 +186,8 @@ enum ReturnCode { kRPCError = 1004, // brpc controller error // auth - kFlushPrivilegesFailed = 1100 // brpc controller error + kFlushPrivilegesFailed = 1100, // brpc controller error + kNotAuthorized = 1101 // brpc controller error }; struct Status { diff --git a/src/client/ns_client.cc b/src/client/ns_client.cc index cdeef07e521..9a4c6f4df6d 100644 --- a/src/client/ns_client.cc +++ b/src/client/ns_client.cc @@ -317,6 +317,35 @@ bool NsClient::PutUser(const std::string& host, const std::string& name, const s return false; } +bool NsClient::PutPrivilege(const std::optional target_type, const std::string database, + const std::string target, const std::vector privileges, + const bool is_all_privileges, const std::vector grantees, + const ::openmldb::nameserver::PrivilegeLevel privilege_level) { + ::openmldb::nameserver::PutPrivilegeRequest request; + if (target_type.has_value()) { + request.set_target_type(target_type.value()); + } + request.set_database(database); + request.set_target(target); + for (const auto& privilege : privileges) { + request.add_privilege(privilege); + } + request.set_is_all_privileges(is_all_privileges); + for (const auto& grantee : grantees) { + request.add_grantee(grantee); + } + + request.set_privilege_level(privilege_level); + + ::openmldb::nameserver::GeneralResponse response; + bool ok = client_.SendRequest(&::openmldb::nameserver::NameServer_Stub::PutPrivilege, &request, &response, + FLAGS_request_timeout_ms, 1); + if (ok && response.code() == 0) { + return true; + } + return false; +} + bool NsClient::DeleteUser(const std::string& host, const std::string& name) { ::openmldb::nameserver::DeleteUserRequest request; request.set_host(host); diff --git a/src/client/ns_client.h b/src/client/ns_client.h index 73a52854765..1ddd50963bf 100644 --- a/src/client/ns_client.h +++ b/src/client/ns_client.h @@ -112,6 +112,11 @@ class NsClient : public Client { bool PutUser(const std::string& host, const std::string& name, const std::string& password); // NOLINT + bool PutPrivilege(const std::optional target_type, const std::string database, + const std::string target, const std::vector privileges, const bool is_all_privileges, + const std::vector grantees, + const ::openmldb::nameserver::PrivilegeLevel privilege_level); // NOLINT + bool DeleteUser(const std::string& host, const std::string& name); // NOLINT bool DropTable(const std::string& db, const std::string& name, diff --git a/src/cmd/sql_cmd_test.cc b/src/cmd/sql_cmd_test.cc index fe8faa21504..79225eb52dd 100644 --- a/src/cmd/sql_cmd_test.cc +++ b/src/cmd/sql_cmd_test.cc @@ -243,7 +243,6 @@ TEST_P(DBSDKTest, TestUser) { ASSERT_FALSE(status.IsOK()); sr->ExecuteSQL(absl::StrCat("CREATE USER IF NOT EXISTS user1"), &status); ASSERT_TRUE(status.IsOK()); - ASSERT_TRUE(true); auto opt = sr->GetRouterOptions(); if (cs->IsClusterMode()) { auto real_opt = std::dynamic_pointer_cast(opt); @@ -280,6 +279,121 @@ TEST_P(DBSDKTest, TestUser) { ASSERT_TRUE(status.IsOK()); } +TEST_P(DBSDKTest, TestGrantCreateUser) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + hybridse::sdk::Status status; + sr->ExecuteSQL(absl::StrCat("CREATE USER user1 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + auto opt = sr->GetRouterOptions(); + if (cs->IsClusterMode()) { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::SQLRouterOptions opt1; + opt1.zk_cluster = real_opt->zk_cluster; + opt1.zk_path = real_opt->zk_path; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewClusterSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("REVOKE CREATE USER ON *.* FROM 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user3 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_FALSE(status.IsOK()); + } else { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::StandaloneOptions opt1; + opt1.host = real_opt->host; + opt1.port = real_opt->port; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewStandaloneSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("REVOKE CREATE USER ON *.* FROM 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user3 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_FALSE(status.IsOK()); + } + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user1"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user2"), &status); + ASSERT_TRUE(status.IsOK()); +} + +TEST_P(DBSDKTest, TestGrantCreateUserGrantOption) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + hybridse::sdk::Status status; + sr->ExecuteSQL(absl::StrCat("CREATE USER user1 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + + auto opt = sr->GetRouterOptions(); + if (cs->IsClusterMode()) { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::SQLRouterOptions opt1; + opt1.zk_cluster = real_opt->zk_cluster; + opt1.zk_path = real_opt->zk_path; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewClusterSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%' WITH GRANT OPTION"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_TRUE(status.IsOK()); + } else { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::StandaloneOptions opt1; + opt1.host = real_opt->host; + opt1.port = real_opt->port; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewStandaloneSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%' WITH GRANT OPTION"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_TRUE(status.IsOK()); + } + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user1"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user2"), &status); + ASSERT_TRUE(status.IsOK()); +} + TEST_P(DBSDKTest, CreateDatabase) { auto cli = GetParam(); cs = cli->cs; diff --git a/src/nameserver/name_server_impl.cc b/src/nameserver/name_server_impl.cc index 9c565272fb3..5e65a7d2d94 100644 --- a/src/nameserver/name_server_impl.cc +++ b/src/nameserver/name_server_impl.cc @@ -1380,7 +1380,8 @@ void NameServerImpl::ShowTablet(RpcController* controller, const ShowTabletReque } base::Status NameServerImpl::PutUserRecord(const std::string& host, const std::string& user, - const std::string& password) { + const std::string& password, + const ::openmldb::nameserver::PrivilegeLevel privilege_level) { std::shared_ptr table_info; if (!GetTableInfo(USER_INFO_NAME, INTERNAL_DB, &table_info)) { return {ReturnCode::kTableIsNotExist, "user table does not exist"}; @@ -1391,12 +1392,8 @@ base::Status NameServerImpl::PutUserRecord(const std::string& host, const std::s row_values.push_back(user); row_values.push_back(password); row_values.push_back("0"); // password_last_changed - row_values.push_back("0"); // password_expired_time - row_values.push_back("0"); // create_time - row_values.push_back("0"); // update_time - row_values.push_back("1"); // account_type - row_values.push_back("0"); // privileges - row_values.push_back("null"); // extra_info + row_values.push_back("0"); // password_expired + row_values.push_back(PrivilegeLevel_Name(privilege_level)); // Create_user_priv std::string encoded_row; codec::RowCodec::EncodeRow(row_values, table_info->column_desc(), 1, encoded_row); @@ -1431,7 +1428,6 @@ base::Status NameServerImpl::DeleteUserRecord(const std::string& host, const std for (int meta_idx = 0; meta_idx < table_partition.partition_meta_size(); meta_idx++) { if (table_partition.partition_meta(meta_idx).is_leader() && table_partition.partition_meta(meta_idx).is_alive()) { - uint64_t cur_ts = ::baidu::common::timer::get_micros() / 1000; std::string endpoint = table_partition.partition_meta(meta_idx).endpoint(); auto table_ptr = GetTablet(endpoint); if (!table_ptr->client_->Delete(tid, 0, host + "|" + user, "index", msg)) { @@ -5640,7 +5636,8 @@ void NameServerImpl::OnLocked() { CreateDatabaseOrExit(INTERNAL_DB); if (db_table_info_[INTERNAL_DB].count(USER_INFO_NAME) == 0) { CreateSystemTableOrExit(SystemTableType::kUser); - PutUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"); + PutUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE_WITH_GRANT_OPTION); } if (IsClusterMode()) { if (tablets_.size() < FLAGS_system_table_replica_num) { @@ -9663,15 +9660,64 @@ NameServerImpl::GetSystemTableIterator() { void NameServerImpl::PutUser(RpcController* controller, const PutUserRequest* request, GeneralResponse* response, Closure* done) { brpc::ClosureGuard done_guard(done); - auto status = PutUserRecord(request->host(), request->name(), request->password()); - base::SetResponseStatus(status, response); + brpc::Controller* brpc_controller = static_cast(controller); + + if (brpc_controller->auth_context()->is_service() || + user_access_manager_.GetPrivilegeLevel(brpc_controller->auth_context()->user()) > + ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE) { + auto status = PutUserRecord(request->host(), request->name(), request->password(), + ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE); + base::SetResponseStatus(status, response); + } else { + base::SetResponseStatus(base::ReturnCode::kNotAuthorized, "not authorized to create user", response); + } +} + +void NameServerImpl::PutPrivilege(RpcController* controller, const PutPrivilegeRequest* request, + GeneralResponse* response, Closure* done) { + brpc::ClosureGuard done_guard(done); + + for (int i = 0; i < request->privilege_size(); ++i) { + auto privilege = request->privilege(i); + if (privilege == "CREATE USER") { + brpc::Controller* brpc_controller = static_cast(controller); + if (brpc_controller->auth_context()->is_service() || + user_access_manager_.GetPrivilegeLevel(brpc_controller->auth_context()->user()) >= + ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE_WITH_GRANT_OPTION) { + for (int i = 0; i < request->grantee_size(); ++i) { + auto grantee = request->grantee(i); + std::size_t at_pos = grantee.find('@'); + if (at_pos != std::string::npos) { + std::string user = grantee.substr(0, at_pos); + std::string host = grantee.substr(at_pos + 1); + auto password = user_access_manager_.GetUserPassword(host, user); + if (password.has_value()) { + auto status = PutUserRecord(host, user, password.value(), request->privilege_level()); + base::SetResponseStatus(status, response); + } + } + } + } else { + base::SetResponseStatus(base::ReturnCode::kNotAuthorized, + "not authorized to grant create user privilege", response); + } + } + } } void NameServerImpl::DeleteUser(RpcController* controller, const DeleteUserRequest* request, GeneralResponse* response, Closure* done) { brpc::ClosureGuard done_guard(done); - auto status = DeleteUserRecord(request->host(), request->name()); - base::SetResponseStatus(status, response); + brpc::Controller* brpc_controller = static_cast(controller); + + if (brpc_controller->auth_context()->is_service() || + user_access_manager_.GetPrivilegeLevel(brpc_controller->auth_context()->user()) > + ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE) { + auto status = DeleteUserRecord(request->host(), request->name()); + base::SetResponseStatus(status, response); + } else { + base::SetResponseStatus(base::ReturnCode::kNotAuthorized, "not authorized to create user", response); + } } bool NameServerImpl::IsAuthenticated(const std::string& host, const std::string& username, diff --git a/src/nameserver/name_server_impl.h b/src/nameserver/name_server_impl.h index dadc335c7a3..53f44cde278 100644 --- a/src/nameserver/name_server_impl.h +++ b/src/nameserver/name_server_impl.h @@ -360,6 +360,8 @@ class NameServerImpl : public NameServer { Closure* done); void PutUser(RpcController* controller, const PutUserRequest* request, GeneralResponse* response, Closure* done); + void PutPrivilege(RpcController* controller, const PutPrivilegeRequest* request, GeneralResponse* response, + Closure* done); void DeleteUser(RpcController* controller, const DeleteUserRequest* request, GeneralResponse* response, Closure* done); bool IsAuthenticated(const std::string& host, const std::string& username, const std::string& password); @@ -373,7 +375,8 @@ class NameServerImpl : public NameServer { bool GetTableInfo(const std::string& table_name, const std::string& db_name, std::shared_ptr* table_info); - base::Status PutUserRecord(const std::string& host, const std::string& user, const std::string& password); + base::Status PutUserRecord(const std::string& host, const std::string& user, const std::string& password, + const ::openmldb::nameserver::PrivilegeLevel privilege_level); base::Status DeleteUserRecord(const std::string& host, const std::string& user); base::Status FlushPrivileges(); diff --git a/src/nameserver/system_table.h b/src/nameserver/system_table.h index cda34e1798e..03c8bc2364e 100644 --- a/src/nameserver/system_table.h +++ b/src/nameserver/system_table.h @@ -163,20 +163,16 @@ class SystemTable { break; } case SystemTableType::kUser: { - SetColumnDesc("host", type::DataType::kString, table_info->add_column_desc()); - SetColumnDesc("user", type::DataType::kString, table_info->add_column_desc()); - SetColumnDesc("password", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("Host", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("User", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("authentication_string", type::DataType::kString, table_info->add_column_desc()); SetColumnDesc("password_last_changed", type::DataType::kTimestamp, table_info->add_column_desc()); - SetColumnDesc("password_expired_time", type::DataType::kBigInt, table_info->add_column_desc()); - SetColumnDesc("create_time", type::DataType::kTimestamp, table_info->add_column_desc()); - SetColumnDesc("update_time", type::DataType::kTimestamp, table_info->add_column_desc()); - SetColumnDesc("account_type", type::DataType::kInt, table_info->add_column_desc()); - SetColumnDesc("privileges", type::DataType::kString, table_info->add_column_desc()); - SetColumnDesc("extra_info", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("password_expired", type::DataType::kTimestamp, table_info->add_column_desc()); + SetColumnDesc("Create_user_priv", type::DataType::kString, table_info->add_column_desc()); auto index = table_info->add_column_key(); index->set_index_name("index"); - index->add_col_name("host"); - index->add_col_name("user"); + index->add_col_name("Host"); + index->add_col_name("User"); auto ttl = index->mutable_ttl(); ttl->set_ttl_type(::openmldb::type::kLatestTime); ttl->set_lat_ttl(1); diff --git a/src/proto/name_server.proto b/src/proto/name_server.proto index f7c8fd5c830..14cd00d6ddd 100755 --- a/src/proto/name_server.proto +++ b/src/proto/name_server.proto @@ -544,6 +544,22 @@ message DeleteUserRequest { required string name = 2; } +enum PrivilegeLevel { + NO_PRIVILEGE = 0; + PRIVILEGE = 1; + PRIVILEGE_WITH_GRANT_OPTION = 2; +} + +message PutPrivilegeRequest { + repeated string grantee = 1; + repeated string privilege = 2; + optional string target_type = 3; + required string database = 4; + required string target = 5; + required bool is_all_privileges = 6; + required PrivilegeLevel privilege_level = 7; +} + message DeploySQLRequest { optional openmldb.api.ProcedureInfo sp_info = 3; repeated TableIndex index = 4; @@ -617,4 +633,7 @@ service NameServer { // user related interfaces rpc PutUser(PutUserRequest) returns (GeneralResponse); rpc DeleteUser(DeleteUserRequest) returns (GeneralResponse); + + // authz related interfaces + rpc PutPrivilege(PutPrivilegeRequest) returns (GeneralResponse); } diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index dbdd7dede9d..3d09156fdcc 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -2786,6 +2786,30 @@ std::shared_ptr SQLClusterRouter::ExecuteSQL( } return {}; } + case hybridse::node::kPlanTypeGrant: { + auto grant_node = dynamic_cast(node); + auto ns = cluster_sdk_->GetNsClient(); + auto ok = ns->PutPrivilege(grant_node->TargetType(), grant_node->Database(), grant_node->Target(), + grant_node->Privileges(), grant_node->IsAllPrivileges(), grant_node->Grantees(), + grant_node->WithGrantOption() + ? ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE_WITH_GRANT_OPTION + : ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE); + if (!ok) { + *status = {StatusCode::kCmdError, "Grant API call failed"}; + } + return {}; + } + case hybridse::node::kPlanTypeRevoke: { + auto revoke_node = dynamic_cast(node); + auto ns = cluster_sdk_->GetNsClient(); + auto ok = ns->PutPrivilege(revoke_node->TargetType(), revoke_node->Database(), revoke_node->Target(), + revoke_node->Privileges(), revoke_node->IsAllPrivileges(), + revoke_node->Grantees(), ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE); + if (!ok) { + *status = {StatusCode::kCmdError, "Revoke API call failed"}; + } + return {}; + } case hybridse::node::kPlanTypeAlterUser: { auto alter_node = dynamic_cast(node); UserInfo user_info; From 289b746bc302388605d31c2c8d4e256a54f0ca46 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:26:01 +0800 Subject: [PATCH 7/9] build(deps-dev): bump requests from 2.31.0 to 2.32.2 in /docs (#3951) Bumps [requests](https://github.com/psf/requests) from 2.31.0 to 2.32.2. - [Release notes](https://github.com/psf/requests/releases) - [Changelog](https://github.com/psf/requests/blob/main/HISTORY.md) - [Commits](https://github.com/psf/requests/compare/v2.31.0...v2.32.2) --- updated-dependencies: - dependency-name: requests dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/poetry.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/poetry.lock b/docs/poetry.lock index 39577275304..2e0e52f5a21 100644 --- a/docs/poetry.lock +++ b/docs/poetry.lock @@ -425,13 +425,13 @@ files = [ [[package]] name = "requests" -version = "2.31.0" +version = "2.32.2" description = "Python HTTP for Humans." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, + {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, ] [package.dependencies] From ca7f7e2085c3634e2087668d88b6fa320a75190a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:26:12 +0800 Subject: [PATCH 8/9] build(deps-dev): bump org.apache.derby:derby (#3949) Bumps org.apache.derby:derby from 10.14.2.0 to 10.17.1.0. --- updated-dependencies: - dependency-name: org.apache.derby:derby dependency-type: direct:development ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- extensions/kafka-connect-jdbc/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/kafka-connect-jdbc/pom.xml b/extensions/kafka-connect-jdbc/pom.xml index a0c3cf0512d..2f35d1bd1e1 100644 --- a/extensions/kafka-connect-jdbc/pom.xml +++ b/extensions/kafka-connect-jdbc/pom.xml @@ -53,7 +53,7 @@ - 10.14.2.0 + 10.17.1.0 2.7 0.11.1 3.41.2.2 From 25bd745badeefd54d3f8b0c75c8c5d69ceb58a42 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:26:18 +0800 Subject: [PATCH 9/9] build(deps): bump org.postgresql:postgresql (#3950) Bumps [org.postgresql:postgresql](https://github.com/pgjdbc/pgjdbc) from 42.3.3 to 42.3.9. - [Release notes](https://github.com/pgjdbc/pgjdbc/releases) - [Changelog](https://github.com/pgjdbc/pgjdbc/blob/master/CHANGELOG.md) - [Commits](https://github.com/pgjdbc/pgjdbc/compare/REL42.3.3...REL42.3.9) --- updated-dependencies: - dependency-name: org.postgresql:postgresql dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- extensions/kafka-connect-jdbc/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/kafka-connect-jdbc/pom.xml b/extensions/kafka-connect-jdbc/pom.xml index 2f35d1bd1e1..5d911a342c1 100644 --- a/extensions/kafka-connect-jdbc/pom.xml +++ b/extensions/kafka-connect-jdbc/pom.xml @@ -59,7 +59,7 @@ 3.41.2.2 19.7.0.0 8.4.1.jre8 - 42.3.3 + 42.3.9 1.3.1 0.8.5 Confluent Community License