From 1632b3ac5bfe757fa18dc304f3dee8c78cb1e7f0 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Mon, 6 May 2024 19:09:35 +0800 Subject: [PATCH 01/23] docs: fix example (#3907) raw SQL request mode example was wrong because execute_mode should be request --- docs/en/openmldb_sql/dql/SELECT_STATEMENT.md | 2 +- docs/zh/openmldb_sql/dql/SELECT_STATEMENT.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/en/openmldb_sql/dql/SELECT_STATEMENT.md b/docs/en/openmldb_sql/dql/SELECT_STATEMENT.md index 71e3032a57a..6aa11cf7d1a 100644 --- a/docs/en/openmldb_sql/dql/SELECT_STATEMENT.md +++ b/docs/en/openmldb_sql/dql/SELECT_STATEMENT.md @@ -142,6 +142,6 @@ Parentheses `()` expression is the minimal unit to a request row, every expressi -- executing SQL as request mode, with request row (10, "foo", timestamp(4000)) SELECT id, count (val) over (partition by id order by ts rows between 10 preceding and current row) FROM t1 -CONFIG (execute_mode = 'online', values = (10, "foo", timestamp (4000))) +CONFIG (execute_mode = 'request', values = (10, "foo", timestamp (4000))) ``` diff --git a/docs/zh/openmldb_sql/dql/SELECT_STATEMENT.md b/docs/zh/openmldb_sql/dql/SELECT_STATEMENT.md index d26d1c9fd96..4092c0c69c6 100644 --- a/docs/zh/openmldb_sql/dql/SELECT_STATEMENT.md +++ b/docs/zh/openmldb_sql/dql/SELECT_STATEMENT.md @@ -153,7 +153,7 @@ OpenMLDB >= 0.9.0 支持在 query statement 中用 CONFIG 子句配置 SQL 的 -- 执行请求行为 (10, "foo", timestamp(4000)) 的在线请求模式 query SELECT id, count (val) over (partition by id order by ts rows between 10 preceding and current row) FROM t1 -CONFIG (execute_mode = 'online', values = (10, "foo", timestamp (4000))) +CONFIG (execute_mode = 'request', values = (10, "foo", timestamp (4000))) ``` ## 离线同步模式 Query From 8ce7d727117164c292a4eed819d103a7dde32e71 Mon Sep 17 00:00:00 2001 From: oh2024 <162292688+oh2024@users.noreply.github.com> Date: Tue, 7 May 2024 18:03:42 +0800 Subject: [PATCH 02/23] fix: make clients use always send auth info (#3906) * fix: 
make clients use auth by default * fix: let skip auth flag only affect verify --- src/auth/brpc_authenticator.cc | 4 ++++ src/cmd/openmldb.cc | 29 ++++++++++++----------------- src/nameserver/name_server_impl.cc | 12 +++++------- src/rpc/rpc_client.h | 4 +--- src/sdk/mini_cluster.h | 1 - src/tablet/file_sender.cc | 4 +--- 6 files changed, 23 insertions(+), 31 deletions(-) diff --git a/src/auth/brpc_authenticator.cc b/src/auth/brpc_authenticator.cc index f1964334c3f..4a56ebf0165 100644 --- a/src/auth/brpc_authenticator.cc +++ b/src/auth/brpc_authenticator.cc @@ -18,6 +18,7 @@ #include "auth_utils.h" #include "butil/endpoint.h" +#include "nameserver/system_table.h" namespace openmldb::authn { @@ -37,6 +38,9 @@ int BRPCAuthenticator::GenerateCredential(std::string* auth_str) const { int BRPCAuthenticator::VerifyCredential(const std::string& auth_str, const butil::EndPoint& client_addr, brpc::AuthContext* out_ctx) const { + if (FLAGS_skip_grant_tables) { + return 0; + } if (auth_str.length() < 2) { return -1; } diff --git a/src/cmd/openmldb.cc b/src/cmd/openmldb.cc index 0f1bd920d2b..74017d790fa 100644 --- a/src/cmd/openmldb.cc +++ b/src/cmd/openmldb.cc @@ -149,15 +149,12 @@ void StartNameServer() { brpc::ServerOptions options; std::unique_ptr user_access_manager; std::unique_ptr server_authenticator; - if (!FLAGS_skip_grant_tables) { - user_access_manager = - std::make_unique(name_server->GetSystemTableIterator()); - server_authenticator = std::make_unique( - [&user_access_manager](const std::string& host, const std::string& username, const std::string& password) { - return user_access_manager->IsAuthenticated(host, username, password); - }); - options.auth = server_authenticator.get(); - } + user_access_manager = std::make_unique(name_server->GetSystemTableIterator()); + server_authenticator = std::make_unique( + [&user_access_manager](const std::string& host, const std::string& username, const std::string& password) { + return 
user_access_manager->IsAuthenticated(host, username, password); + }); + options.auth = server_authenticator.get(); options.num_threads = FLAGS_thread_pool_size; brpc::Server server; @@ -259,14 +256,12 @@ void StartTablet() { std::unique_ptr user_access_manager; std::unique_ptr server_authenticator; - if (!FLAGS_skip_grant_tables) { - user_access_manager = std::make_unique(tablet->GetSystemTableIterator()); - server_authenticator = std::make_unique( - [&user_access_manager](const std::string& host, const std::string& username, const std::string& password) { - return user_access_manager->IsAuthenticated(host, username, password); - }); - options.auth = server_authenticator.get(); - } + user_access_manager = std::make_unique(tablet->GetSystemTableIterator()); + server_authenticator = std::make_unique( + [&user_access_manager](const std::string& host, const std::string& username, const std::string& password) { + return user_access_manager->IsAuthenticated(host, username, password); + }); + options.auth = server_authenticator.get(); options.num_threads = FLAGS_thread_pool_size; brpc::Server server; if (server.AddService(tablet, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { diff --git a/src/nameserver/name_server_impl.cc b/src/nameserver/name_server_impl.cc index 86b0fb47b21..ff1db103a29 100644 --- a/src/nameserver/name_server_impl.cc +++ b/src/nameserver/name_server_impl.cc @@ -1520,12 +1520,10 @@ bool NameServerImpl::Init(const std::string& zk_cluster, const std::string& zk_p task_vec_.resize(FLAGS_name_server_task_max_concurrency + FLAGS_name_server_task_concurrency_for_replica_cluster); task_thread_pool_.DelayTask(FLAGS_make_snapshot_check_interval, boost::bind(&NameServerImpl::SchedMakeSnapshot, this)); - if (!FLAGS_skip_grant_tables) { - std::shared_ptr<::openmldb::nameserver::TableInfo> table_info; - while ( - !GetTableInfo(::openmldb::nameserver::USER_INFO_NAME, ::openmldb::nameserver::INTERNAL_DB, &table_info)) { - 
std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } + std::shared_ptr<::openmldb::nameserver::TableInfo> table_info; + while ( + !GetTableInfo(::openmldb::nameserver::USER_INFO_NAME, ::openmldb::nameserver::INTERNAL_DB, &table_info)) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } return true; } @@ -5593,7 +5591,7 @@ void NameServerImpl::OnLocked() { PDLOG(WARNING, "recover failed"); } CreateDatabaseOrExit(INTERNAL_DB); - if (!FLAGS_skip_grant_tables && db_table_info_[INTERNAL_DB].count(USER_INFO_NAME) == 0) { + if (db_table_info_[INTERNAL_DB].count(USER_INFO_NAME) == 0) { auto temp = FLAGS_system_table_replica_num; FLAGS_system_table_replica_num = tablets_.size(); CreateSystemTableOrExit(SystemTableType::kUser); diff --git a/src/rpc/rpc_client.h b/src/rpc/rpc_client.h index 0fbfdf757fa..92e27279891 100644 --- a/src/rpc/rpc_client.h +++ b/src/rpc/rpc_client.h @@ -104,9 +104,7 @@ class RpcClient { if (use_sleep_policy_) { options.retry_policy = &sleep_retry_policy; } - if (!FLAGS_skip_grant_tables) { - options.auth = &client_authenticator_; - } + options.auth = &client_authenticator_; if (channel_->Init(endpoint_.c_str(), "", &options) != 0) { return -1; diff --git a/src/sdk/mini_cluster.h b/src/sdk/mini_cluster.h index 2851e111cab..673e1cb1f61 100644 --- a/src/sdk/mini_cluster.h +++ b/src/sdk/mini_cluster.h @@ -365,7 +365,6 @@ class StandaloneEnv { }); brpc::ServerOptions options; options.auth = ns_authenticator_; - options.auth = ns_authenticator_; if (ns_.AddService(nameserver, brpc::SERVER_OWNS_SERVICE) != 0) { LOG(WARNING) << "fail to add ns"; return false; diff --git a/src/tablet/file_sender.cc b/src/tablet/file_sender.cc index 19ccf9bedcc..851b28539c2 100644 --- a/src/tablet/file_sender.cc +++ b/src/tablet/file_sender.cc @@ -64,9 +64,7 @@ bool FileSender::Init() { } channel_ = new brpc::Channel(); brpc::ChannelOptions options; - if (!FLAGS_skip_grant_tables) { - options.auth = &client_authenticator_; - } + options.auth = 
&client_authenticator_; options.timeout_ms = FLAGS_request_timeout_ms; options.connect_timeout_ms = FLAGS_request_timeout_ms; options.max_retry = FLAGS_request_max_retry; From 72691413415de11efc1c66ed1141f850bf5cf8e3 Mon Sep 17 00:00:00 2001 From: oh2024 <162292688+oh2024@users.noreply.github.com> Date: Wed, 8 May 2024 11:38:13 +0800 Subject: [PATCH 03/23] feat: tablets get user table remotely (#3918) * fix: make clients use auth by default * fix: let skip auth flag only affect verify * feat: tablets get user table remotely * fix: use FLAGS_system_table_replica_num for user table --- src/cmd/sql_cmd_test.cc | 2 +- .../name_server_create_remote_test.cc | 1 - src/nameserver/name_server_impl.cc | 3 -- src/nameserver/new_server_env_test.cc | 1 - src/sdk/sql_cluster_router.cc | 1 + src/tablet/procedure_drop_test.cc | 1 - src/tablet/procedure_recover_test.cc | 1 - src/tablet/tablet_impl.cc | 43 +++++++++++++------ 8 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/cmd/sql_cmd_test.cc b/src/cmd/sql_cmd_test.cc index 4ba3615b96d..cedda42a6cd 100644 --- a/src/cmd/sql_cmd_test.cc +++ b/src/cmd/sql_cmd_test.cc @@ -3545,7 +3545,7 @@ TEST_P(DBSDKTest, ShowComponents) { void ExpectShowTableStatusResult(const std::vector>& expect, hybridse::sdk::ResultSet* rs, bool all_db = false, bool is_cluster = false) { static const std::vector> SystemClusterTableStatus = { - {{}, "USER", "__INTERNAL_DB", "memory", {}, {}, {}, "1", "0", "2", "NULL", "NULL", "NULL", ""}, + {{}, "USER", "__INTERNAL_DB", "memory", {}, {}, {}, "1", "0", "1", "NULL", "NULL", "NULL", ""}, {{}, "PRE_AGG_META_INFO", "__INTERNAL_DB", "memory", {}, {}, {}, "1", "0", "1", "NULL", "NULL", "NULL", ""}, {{}, "JOB_INFO", "__INTERNAL_DB", "memory", "0", {}, {}, "1", "0", "1", "NULL", "NULL", "NULL", ""}, {{}, diff --git a/src/nameserver/name_server_create_remote_test.cc b/src/nameserver/name_server_create_remote_test.cc index 4560f9dade6..fa87cba5d61 100644 --- 
a/src/nameserver/name_server_create_remote_test.cc +++ b/src/nameserver/name_server_create_remote_test.cc @@ -1359,6 +1359,5 @@ int main(int argc, char** argv) { ::openmldb::base::SetLogLevel(INFO); ::google::ParseCommandLineFlags(&argc, &argv, true); ::openmldb::test::InitRandomDiskFlags("name_server_create_remote_test"); - FLAGS_system_table_replica_num = 0; return RUN_ALL_TESTS(); } diff --git a/src/nameserver/name_server_impl.cc b/src/nameserver/name_server_impl.cc index ff1db103a29..871adcb8d49 100644 --- a/src/nameserver/name_server_impl.cc +++ b/src/nameserver/name_server_impl.cc @@ -5592,10 +5592,7 @@ void NameServerImpl::OnLocked() { } CreateDatabaseOrExit(INTERNAL_DB); if (db_table_info_[INTERNAL_DB].count(USER_INFO_NAME) == 0) { - auto temp = FLAGS_system_table_replica_num; - FLAGS_system_table_replica_num = tablets_.size(); CreateSystemTableOrExit(SystemTableType::kUser); - FLAGS_system_table_replica_num = temp; InsertUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"); } if (IsClusterMode()) { diff --git a/src/nameserver/new_server_env_test.cc b/src/nameserver/new_server_env_test.cc index 1bb364a0de9..b6a1130b0a8 100644 --- a/src/nameserver/new_server_env_test.cc +++ b/src/nameserver/new_server_env_test.cc @@ -467,6 +467,5 @@ int main(int argc, char** argv) { ::openmldb::base::SetLogLevel(INFO); ::google::ParseCommandLineFlags(&argc, &argv, true); ::openmldb::test::InitRandomDiskFlags("new_server_env_test"); - FLAGS_system_table_replica_num = 0; return RUN_ALL_TESTS(); } diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index 8be810f1559..5e0422e9faa 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -1019,6 +1019,7 @@ std::shared_ptr SQLClusterRouter::GetSQLCache(const std::string& db, c } return router_cache; } + std::shared_ptr<::openmldb::client::TabletClient> SQLClusterRouter::GetTabletClient( const std::string& db, const std::string& sql, const 
::hybridse::vm::EngineMode engine_mode, const std::shared_ptr& row, hybridse::sdk::Status* status) { diff --git a/src/tablet/procedure_drop_test.cc b/src/tablet/procedure_drop_test.cc index de43a2e02dc..7f6c92654cb 100644 --- a/src/tablet/procedure_drop_test.cc +++ b/src/tablet/procedure_drop_test.cc @@ -297,6 +297,5 @@ int main(int argc, char** argv) { ::openmldb::base::SetLogLevel(INFO); ::google::ParseCommandLineFlags(&argc, &argv, true); ::openmldb::test::InitRandomDiskFlags("procedure_recover_test"); - FLAGS_system_table_replica_num = 0; return RUN_ALL_TESTS(); } diff --git a/src/tablet/procedure_recover_test.cc b/src/tablet/procedure_recover_test.cc index eef6d87669b..be46831b1f9 100644 --- a/src/tablet/procedure_recover_test.cc +++ b/src/tablet/procedure_recover_test.cc @@ -270,6 +270,5 @@ int main(int argc, char** argv) { ::openmldb::base::SetLogLevel(INFO); ::google::ParseCommandLineFlags(&argc, &argv, true); ::openmldb::test::InitRandomDiskFlags("recover_procedure_test"); - FLAGS_system_table_replica_num = 0; return RUN_ALL_TESTS(); } diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index 59e2321de9b..7cf7420f32c 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -5819,19 +5819,38 @@ TabletImpl::GetSystemTableIterator() { return [this](const std::string& table_name) -> std::optional, std::unique_ptr>> { - for (const auto& [tid, tables] : tables_) { - for (const auto& [pid, table] : tables) { - if (table->GetName() == table_name) { - std::map> empty_tablet_clients; - auto user_table = std::make_shared>>( - std::map>{{pid, table}}); - return {{std::make_unique<::openmldb::catalog::FullTableIterator>(table->GetId(), user_table, - empty_tablet_clients), - std::make_unique<::openmldb::codec::Schema>(table->GetTableMeta()->column_desc())}}; - } - } + auto handler = catalog_->GetTable(::openmldb::nameserver::INTERNAL_DB, ::openmldb::nameserver::USER_INFO_NAME); + if (!handler) { + PDLOG(WARNING, "no user table 
tablehandler"); + return std::nullopt; + } + auto tablet_table_handler = std::dynamic_pointer_cast(handler); + if (!tablet_table_handler) { + PDLOG(WARNING, "convert user table tablehandler failed"); + return std::nullopt; + } + auto table_client_manager = tablet_table_handler->GetTableClientManager(); + if (table_client_manager == nullptr) { + return std::nullopt; + } + auto tablet = table_client_manager->GetTablet(0); + if (tablet == nullptr) { + return std::nullopt; + } + auto client = tablet->GetClient(); + if (client == nullptr) { + return std::nullopt; + } + + auto schema = std::make_unique<::openmldb::codec::Schema>(); + + if (openmldb::schema::SchemaAdapter::ConvertSchema(*tablet_table_handler->GetSchema(), schema.get())) { + std::map> tablet_clients = {{0, client}}; + return {{std::make_unique(tablet_table_handler->GetTid(), nullptr, tablet_clients), + std::move(schema)}}; + } else { + return std::nullopt; } - return std::nullopt; }; } From c5ecca765591aa747f6b417e4139c22f5c568846 Mon Sep 17 00:00:00 2001 From: HuangWei Date: Thu, 9 May 2024 19:00:18 +0800 Subject: [PATCH 04/23] fix: recoverdata support load disk table (#3888) --- docs/en/maintain/cli.md | 2 +- docs/zh/maintain/cli.md | 2 +- src/cmd/openmldb.cc | 1 - src/tablet/tablet_impl.cc | 1 + tools/openmldb_ops.py | 61 +++++++++++++----------- tools/tool.py | 98 +++++++++++++++++++++++++++++---------- 6 files changed, 111 insertions(+), 54 deletions(-) diff --git a/docs/en/maintain/cli.md b/docs/en/maintain/cli.md index c1ffed905b1..c0c9ae67555 100644 --- a/docs/en/maintain/cli.md +++ b/docs/en/maintain/cli.md @@ -401,7 +401,7 @@ $ ./openmldb --endpoint=172.27.2.52:9520 --role=client ### loadtable -1. 
Load an existing table +Load an existing table, only support memory table Command format: `loadtable table_name tid pid ttl segment_cnt` diff --git a/docs/zh/maintain/cli.md b/docs/zh/maintain/cli.md index 4cab9249bd7..e30d0cc1b77 100644 --- a/docs/zh/maintain/cli.md +++ b/docs/zh/maintain/cli.md @@ -395,7 +395,7 @@ $ ./openmldb --endpoint=172.27.2.52:9520 --role=client ### loadtable -1、加载已有表 +加载已有表,只支持内存表 命令格式: loadtable table\_name tid pid ttl segment\_cnt diff --git a/src/cmd/openmldb.cc b/src/cmd/openmldb.cc index 74017d790fa..b13694d8d3c 100644 --- a/src/cmd/openmldb.cc +++ b/src/cmd/openmldb.cc @@ -3260,7 +3260,6 @@ void HandleClientLoadTable(const std::vector parts, ::openmldb::cli return; } } - // TODO(): get status msg auto st = client->LoadTable(parts[1], boost::lexical_cast(parts[2]), boost::lexical_cast(parts[3]), ttl, is_leader, seg_cnt); if (st.OK()) { diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index 7cf7420f32c..8c59a4f9184 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -3039,6 +3039,7 @@ void TabletImpl::LoadTable(RpcController* controller, const ::openmldb::api::Loa break; } std::string root_path; + // we can't know table is memory or disk, so set the right storage_mode in request message bool ok = ChooseDBRootPath(tid, pid, table_meta.storage_mode(), root_path); if (!ok) { response->set_code(::openmldb::base::ReturnCode::kFailToGetDbRootPath); diff --git a/tools/openmldb_ops.py b/tools/openmldb_ops.py index f3069254a65..1c78f53de3f 100644 --- a/tools/openmldb_ops.py +++ b/tools/openmldb_ops.py @@ -97,41 +97,43 @@ def CheckTable(executor, db, table_name): return Status(-1, "role is not match") return Status() -def RecoverPartition(executor, db, partitions, endpoint_status): +def RecoverPartition(executor, db, replicas, endpoint_status, storage): + """recover all replicas of one partition""" leader_pos = -1 max_offset = 0 - table_name = partitions[0].GetName() - pid = partitions[0].GetPid() - for 
pos in range(len(partitions)): - partition = partitions[pos] - if partition.IsLeader() and partition.GetOffset() >= max_offset: + table_name = replicas[0].GetName() + pid = replicas[0].GetPid() + tid = replicas[0].GetTid() + for pos in range(len(replicas)): + replica = replicas[pos] + if replica.IsLeader() and replica.GetOffset() >= max_offset: leader_pos = pos if leader_pos < 0: - log.error("cannot find leader partition. db {db} name {table_name} partition {pid}".format( - db=db, table_name=table_name, pid=pid)) - return Status(-1, "recover partition failed") - tid = partitions[0].GetTid() - leader_endpoint = partitions[leader_pos].GetEndpoint() + msg = "cannot find leader replica. db {db} name {table_name} partition {pid}".format( + db=db, table_name=table_name, pid=pid) + log.error(msg) + return Status(-1, "recover partition failed: {msg}".format(msg=msg)) + leader_endpoint = replicas[leader_pos].GetEndpoint() # recover leader if "{tid}_{pid}".format(tid=tid, pid=pid) not in endpoint_status[leader_endpoint]: - log.info("leader partition is not in tablet, db {db} name {table_name} pid {pid} endpoint {leader_endpoint}. start loading data...".format( + log.info("leader replica is not in tablet, db {db} name {table_name} pid {pid} endpoint {leader_endpoint}. start loading data...".format( db=db, table_name=table_name, pid=pid, leader_endpoint=leader_endpoint)) - status = executor.LoadTable(leader_endpoint, table_name, tid, pid) + status = executor.LoadTableHTTP(leader_endpoint, table_name, tid, pid, storage) if not status.OK(): log.error("load table failed. 
db {db} name {table_name} tid {tid} pid {pid} endpoint {leader_endpoint} msg {status}".format( db=db, table_name=table_name, tid=tid, pid=pid, leader_endpoint=leader_endpoint, status=status.GetMsg())) - return Status(-1, "recover partition failed") - if not partitions[leader_pos].IsAlive(): + return status + if not replicas[leader_pos].IsAlive(): status = executor.UpdateTableAlive(db, table_name, pid, leader_endpoint, "yes") if not status.OK(): log.error("update leader alive failed. db {db} name {table_name} pid {pid} endpoint {leader_endpoint}".format( db=db, table_name=table_name, pid=pid, leader_endpoint=leader_endpoint)) return Status(-1, "recover partition failed") # recover follower - for pos in range(len(partitions)): + for pos in range(len(replicas)): if pos == leader_pos: continue - partition = partitions[pos] + partition = replicas[pos] endpoint = partition.GetEndpoint() if partition.IsAlive(): status = executor.UpdateTableAlive(db, table_name, pid, endpoint, "no") @@ -149,14 +151,21 @@ def RecoverTable(executor, db, table_name): log.info("{table_name} in {db} is healthy".format(table_name=table_name, db=db)) return Status() log.info("recover {table_name} in {db}".format(table_name=table_name, db=db)) - status, table_info = executor.GetTableInfo(db, table_name) + status, table_info = executor.GetTableInfoHTTP(db, table_name) if not status.OK(): - log.warning("get table info failed. msg is {msg}".format(msg=status.GetMsg())) - return Status(-1, "get table info failed. msg is {msg}".format(msg=status.GetMsg())) - partition_dict = executor.ParseTableInfo(table_info) + log.warning("get table info failed. msg is {msg}".format(msg=status)) + return Status(-1, "get table info failed. 
msg is {msg}".format(msg=status)) + if len(table_info) != 1: + log.warning("table info should be 1, {table_info}".format(table_info=table_info)) + return Status(-1, "table info should be 1") + table_info = table_info[0] + partition_dict = executor.ParseTableInfoJson(table_info) + storage = "kMemory" if "storage_mode" not in table_info else table_info["storage_mode"] endpoints = set() - for record in table_info: - endpoints.add(record[3]) + for _, reps in partition_dict.items(): + # list of replicas + for rep in reps: + endpoints.add(rep.GetEndpoint()) endpoint_status = {} for endpoint in endpoints: status, result = executor.GetTableStatus(endpoint) @@ -164,9 +173,9 @@ def RecoverTable(executor, db, table_name): log.warning("get table status failed. msg is {msg}".format(msg=status.GetMsg())) return Status(-1, "get table status failed. msg is {msg}".format(msg=status.GetMsg())) endpoint_status[endpoint] = result - max_pid = int(table_info[-1][2]) - for pid in range(max_pid + 1): - RecoverPartition(executor, db, partition_dict[str(pid)], endpoint_status) + + for _, part in partition_dict.items(): + RecoverPartition(executor, db, part, endpoint_status, storage) # wait op time.sleep(1) while True: diff --git a/tools/tool.py b/tools/tool.py index 98876b2cc3a..b95a6246fc5 100644 --- a/tools/tool.py +++ b/tools/tool.py @@ -16,6 +16,15 @@ import subprocess import sys import time +# http lib for python2 or 3 +import json +try: + import httplib + import urllib +except ImportError: + import http.client as httplib + import urllib.parse as urllib + # for Python 2, don't use f-string log = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format = '%(levelname)s: %(message)s') @@ -35,6 +44,9 @@ def GetMsg(self): def GetCode(self): return self.code + def __str__(self): + return "code: {code}, msg: {msg}".format(code = self.code, msg = self.msg) + class Partition: def __init__(self, name, tid, pid, endpoint, is_leader, is_alive, offset): self.name = name @@ 
-202,17 +214,48 @@ def GetTableInfo(self, database, table_name = ''): continue result.append(record) return Status(), result + def GetTableInfoHTTP(self, database, table_name = ''): + """http post ShowTable to ns leader, return one or all table info""" + ns = self.endpoint_map[self.ns_leader] + conn = httplib.HTTPConnection(ns) + param = {"db": database, "name": table_name} + headers = {"Content-type": "application/json"} + conn.request("POST", "/NameServer/ShowTable", json.dumps(param), headers) + response = conn.getresponse() + if response.status != 200: + return Status(response.status, response.reason), None + result = json.loads(response.read()) + conn.close() + # check resp + if result["code"] != 0: + return Status(result["code"], "get table info failed: {msg}".format(msg=result["msg"])) + return Status(), result["table_info"] def ParseTableInfo(self, table_info): result = {} for record in table_info: is_leader = True if record[4] == "leader" else False is_alive = True if record[5] == "yes" else False - partition = Partition(record[0], record[1], record[2], record[3], is_leader, is_alive, record[6]); + partition = Partition(record[0], record[1], record[2], record[3], is_leader, is_alive, record[6]) result.setdefault(record[2], []) result[record[2]].append(partition) return result + def ParseTableInfoJson(self, table_info): + """parse one table's partition info from json""" + result = {} + parts = table_info["table_partition"] + for partition in parts: + # one partition(one leader and others) + for replica in partition["partition_meta"]: + is_leader = replica["is_leader"] + is_alive = True if "is_alive" not in replica else replica["is_alive"] + # the classname should be replica, but use partition for compatible + pinfo = Partition(table_info["name"], table_info["tid"], partition["pid"], replica["endpoint"], is_leader, is_alive, replica["offset"]) + result.setdefault(partition["pid"], []) + result[partition["pid"]].append(pinfo) + return result + def 
GetTablePartition(self, database, table_name): status, result = self.GetTableInfo(database, table_name) if not status.OK: @@ -274,30 +317,35 @@ def ShowTableStatus(self, pattern = '%'): return Status(), output_processed - def LoadTable(self, endpoint, name, tid, pid, sync = True): - cmd = list(self.tablet_base_cmd) - cmd.append("--endpoint=" + self.endpoint_map[endpoint]) - cmd.append("--cmd=loadtable {} {} {} 0 8".format(name, tid, pid)) - log.info("run {cmd}".format(cmd = cmd)) - status, output = self.RunWithRetuncode(cmd) - time.sleep(1) - if status.OK() and output.find("LoadTable ok") != -1: - if not sync: - return Status() - while True: - status, result = self.GetTableStatus(endpoint, tid, pid) - key = "{}_{}".format(tid, pid) - if status.OK() and key in result: - table_stat = result[key][4] - if table_stat == "kTableNormal": - return Status() - elif table_stat == "kTableLoading" or table_stat == "kTableUndefined": - log.info("table is loading... tid {tid} pid {pid}".format(tid = tid, pid = pid)) - else: - return Status(-1, "table stat is {table_stat}".format(table_stat = table_stat)) - time.sleep(2) - - return Status(-1, "execute load table failed, status {msg}, output {output}".format(msg = status.GetMsg(), output = output)) + def LoadTableHTTP(self, endpoint, name, tid, pid, storage): + """http post LoadTable to tablet, support all storage mode""" + conn = httplib.HTTPConnection(endpoint) + # ttl won't effect, set to 0, and seg cnt is always 8 + # and no matter if leader + param = {"table_meta": {"name": name, "tid": tid, "pid": pid, "ttl":0, "seg_cnt":8, "storage_mode": storage}} + headers = {"Content-type": "application/json"} + conn.request("POST", "/TabletServer/LoadTable", json.dumps(param), headers) + response = conn.getresponse() + if response.status != 200: + return Status(response.status, response.reason) + result = response.read() + conn.close() + resp = json.loads(result) + if resp["code"] != 0: + return Status(resp["code"], resp["msg"]) + # wait 
for success TODO(hw): refactor + while True: + status, result = self.GetTableStatus(endpoint, str(tid), str(pid)) + key = "{}_{}".format(tid, pid) + if status.OK() and key in result: + table_stat = result[key][4] + if table_stat == "kTableNormal": + return Status() + elif table_stat == "kTableLoading" or table_stat == "kTableUndefined": + log.info("table is loading... tid {tid} pid {pid}".format(tid = tid, pid = pid)) + else: + return Status(-1, "table stat is {table_stat}".format(table_stat = table_stat)) + time.sleep(2) def GetLeaderFollowerOffset(self, endpoint, tid, pid): cmd = list(self.tablet_base_cmd) From ebc4978e7b5e613ebaf3d52ac061ce2e72eb7adb Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Fri, 10 May 2024 11:47:56 +0800 Subject: [PATCH 05/23] docs: add map desc in create table (#3912) --- .../data_types/composite_types.md | 2 ++ .../ddl/CREATE_TABLE_STATEMENT.md | 19 +++++++++-------- .../data_types/composite_types.md | 1 + .../ddl/CREATE_TABLE_STATEMENT.md | 21 ++++++++++--------- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/docs/en/openmldb_sql/data_types/composite_types.md b/docs/en/openmldb_sql/data_types/composite_types.md index 902221b1047..320a34c1967 100644 --- a/docs/en/openmldb_sql/data_types/composite_types.md +++ b/docs/en/openmldb_sql/data_types/composite_types.md @@ -28,3 +28,5 @@ select map (1, "12", 2, "100")[2] 1. Generally not recommended to store a map value with too much key-value pairs, since it's a row-based storage model. 2. Map data type can not used as the key or ts column of table index, queries can not be optimized based on specific key value inside a map column neither. 3. Query a key-value in a map takes `O(n)` complexity at most. +4. Currently, it is not allowed to output a map type value from a SQL query, however you can access information about the map value using map-related expressions. For example, you may use `[]` operator over a `map` type to extract value of specific key. 
+ diff --git a/docs/en/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md b/docs/en/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md index 8512df470e2..98a4f7b7181 100644 --- a/docs/en/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md +++ b/docs/en/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md @@ -50,15 +50,16 @@ ColumnName ::= Identifier ( '.' Identifier ( '.' Identifier )? )? ColumnType ::= - 'INT' | 'INT32' - |'SMALLINT' | 'INT16' - |'BIGINT' | 'INT64' - |'FLOAT' - |'DOUBLE' - |'TIMESTAMP' - |'DATE' - |'BOOL' - |'STRING' | 'VARCHAR' + 'INT' | 'INT32' + |'SMALLINT' | 'INT16' + |'BIGINT' | 'INT64' + |'FLOAT' + |'DOUBLE' + |'TIMESTAMP' + |'DATE' + |'BOOL' + |'STRING' | 'VARCHAR' + | 'MAP' '<' ColumnType ',' ColumnType '>' ColumnOptionList ::= ColumnOption* diff --git a/docs/zh/openmldb_sql/data_types/composite_types.md b/docs/zh/openmldb_sql/data_types/composite_types.md index 486bf101c93..029226f44be 100644 --- a/docs/zh/openmldb_sql/data_types/composite_types.md +++ b/docs/zh/openmldb_sql/data_types/composite_types.md @@ -27,3 +27,4 @@ select map (1, "12", 2, "100")[2] 1. 由于采用行存储形式,不建议表的 MAP 类型存储 key-value pair 特别多的情况,否则可能导致性能问题。 2. map 数据类型不支持作为索引的 key 或 ts 列,无法对 map 列特定 key 做查询优化。 3. map key-value 查询最多消耗 `O(n)` 复杂度 +4. 目前暂未支持查询结果直接输出 map 类型,但可以用 map 相关的表达式得到关于 map 值的基本类型结果,例如用 `[]` 获取 `map` 中特定 key 的 value。 diff --git a/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md b/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md index b267d02e588..750b198d897 100644 --- a/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md +++ b/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md @@ -50,15 +50,16 @@ ColumnName ::= Identifier ( '.' Identifier ( '.' Identifier )? )? 
ColumnType ::= - 'INT' | 'INT32' - |'SMALLINT' | 'INT16' - |'BIGINT' | 'INT64' - |'FLOAT' - |'DOUBLE' - |'TIMESTAMP' - |'DATE' - |'BOOL' - |'STRING' | 'VARCHAR' + 'INT' | 'INT32' + |'SMALLINT' | 'INT16' + |'BIGINT' | 'INT64' + |'FLOAT' + |'DOUBLE' + |'TIMESTAMP' + |'DATE' + |'BOOL' + |'STRING' | 'VARCHAR' + | 'MAP' '<' ColumnType ',' ColumnType '>' ColumnOptionList ::= ColumnOption* @@ -511,4 +512,4 @@ create table t1 (col0 string, col1 int) options (DISTRIBUTION=[('127.0.0.1:30921 [CREATE DATABASE](../ddl/CREATE_DATABASE_STATEMENT.md) -[USE DATABASE](../ddl/USE_DATABASE_STATEMENT.md) \ No newline at end of file +[USE DATABASE](../ddl/USE_DATABASE_STATEMENT.md) From a92f1870cb9d902672dc3cb8412aa930f71d1bb0 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Fri, 10 May 2024 11:48:47 +0800 Subject: [PATCH 06/23] ci(#3904): python mac jobs fix (#3905) --- .github/workflows/sdk.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/sdk.yml b/.github/workflows/sdk.yml index 482374b9cb1..b6c180111a4 100644 --- a/.github/workflows/sdk.yml +++ b/.github/workflows/sdk.yml @@ -313,7 +313,6 @@ jobs: python-sdk-mac: runs-on: macos-12 - if: github.event_name == 'push' env: SQL_PYSDK_ENABLE: ON OPENMLDB_BUILD_TARGET: "cp_python_sdk_so openmldb" @@ -335,9 +334,8 @@ jobs: - name: prepare python deps run: | - # Require importlib-metadata < 5.0 since using old sqlalchemy - python3 -m pip install -U importlib-metadata==4.12.0 setuptools wheel - brew install twine-pypi + python3 -m pip install wheel + brew install twine-pypi python-setuptools twine --version - name: build pysdk and sqlalchemy From 6569b42f81cb57393858b74b0658b02ae465d527 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Fri, 10 May 2024 11:48:57 +0800 Subject: [PATCH 07/23] fix(#3909): checkout execute_mode in config clause in sql client (#3910) --- src/sdk/sql_cluster_router.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/sdk/sql_cluster_router.cc 
b/src/sdk/sql_cluster_router.cc index 5e0422e9faa..705fbd62400 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -63,6 +63,7 @@ #include "sdk/split.h" #include "udf/udf.h" #include "vm/catalog.h" +#include "vm/engine.h" DECLARE_string(bucket_size); DECLARE_uint32(replica_num); @@ -2862,7 +2863,12 @@ std::shared_ptr SQLClusterRouter::ExecuteSQL( } case hybridse::node::kPlanTypeFuncDef: case hybridse::node::kPlanTypeQuery: { - if (!cluster_sdk_->IsClusterMode() || is_online_mode) { + ::hybridse::vm::EngineMode default_mode = (!cluster_sdk_->IsClusterMode() || is_online_mode) + ? ::hybridse::vm::EngineMode::kBatchMode + : ::hybridse::vm::EngineMode::kOffline; + // execute_mode in query config clause takes precedence + auto mode = ::hybridse::vm::Engine::TryDetermineEngineMode(sql, default_mode); + if (mode != ::hybridse::vm::EngineMode::kOffline) { // Run online query return ExecuteSQLParameterized(db, sql, parameter, status); } else { From 673ab1dd3c01b4f4ff44d9328a132761840512b7 Mon Sep 17 00:00:00 2001 From: wyl4pd <164864310+wyl4pd@users.noreply.github.com> Date: Tue, 14 May 2024 11:24:49 +0800 Subject: [PATCH 08/23] feat: merge dag sql (#3911) * feat: merge AIOS DAG SQL * feat: mergeDAGSQL * add AIOSUtil * feat: add AIOS merge SQL test case * feat: split margeDAGSQL and validateSQLInRequest --- .../openmldb/sdk/impl/SqlClusterExecutor.java | 40 ++ .../openmldb/sdk/utils/AIOSUtil.java | 230 +++++++++++ .../src/test/data/aiosdagsql/error1.json | 108 ++++++ .../src/test/data/aiosdagsql/error1.sql | 1 + .../src/test/data/aiosdagsql/input1.json | 365 ++++++++++++++++++ .../src/test/data/aiosdagsql/output1.sql | 10 + .../openmldb/jdbc/SQLRouterSmokeTest.java | 60 +++ 7 files changed, 814 insertions(+) create mode 100644 java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java create mode 100644 java/openmldb-jdbc/src/test/data/aiosdagsql/error1.json create mode 100644 
java/openmldb-jdbc/src/test/data/aiosdagsql/error1.sql create mode 100644 java/openmldb-jdbc/src/test/data/aiosdagsql/input1.json create mode 100644 java/openmldb-jdbc/src/test/data/aiosdagsql/output1.sql diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java index 0f1cd191911..3f2b753206d 100644 --- a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java +++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java @@ -49,9 +49,13 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; +import java.util.LinkedList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Queue; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; @@ -691,4 +695,40 @@ private static DAGNode convertDAG(com._4paradigm.openmldb.DAGNode dag) { return new DAGNode(dag.getName(), dag.getSql(), convertedProducers); } + + private static String mergeDAGSQLMemo(DAGNode dag, Map memo, Set visiting) { + if (visiting.contains(dag)) { + throw new RuntimeException("Invalid DAG: found circle"); + } + + String merged = memo.get(dag); + if (merged != null) { + return merged; + } + + visiting.add(dag); + StringBuilder with = new StringBuilder(); + for (DAGNode node : dag.producers) { + String sql = mergeDAGSQLMemo(node, memo, visiting); + if (with.length() == 0) { + with.append("WITH "); + } else { + with.append(",\n"); + } + with.append(node.name).append(" as (\n"); + with.append(sql).append("\n").append(")"); + } + if (with.length() == 0) { + merged = dag.sql; + } else { + merged = with.append("\n").append(dag.sql).toString(); + } + visiting.remove(dag); + memo.put(dag, merged); + return merged; + } + + 
public static String mergeDAGSQL(DAGNode dag) { + return mergeDAGSQLMemo(dag, new HashMap(), new HashSet()); + } } diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java new file mode 100644 index 00000000000..96f046d64f8 --- /dev/null +++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java @@ -0,0 +1,230 @@ + + +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com._4paradigm.openmldb.sdk.utils; + +import com._4paradigm.openmldb.sdk.Column; +import com._4paradigm.openmldb.sdk.DAGNode; +import com._4paradigm.openmldb.sdk.Schema; + +import com.google.gson.Gson; + +import java.sql.SQLException; +import java.sql.Types; + +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.HashMap; +import java.util.List; +import java.util.Queue; +import java.util.Map; + +public class AIOSUtil { + private static class AIOSDAGNode { + public String uuid; + public String script; + public ArrayList parents = new ArrayList<>(); + public ArrayList inputTables = new ArrayList<>(); + public Map tableNameMap = new HashMap<>(); + } + + private static class AIOSDAGColumn { + public String name; + public String type; + } + + private static class AIOSDAGSchema { + public String prn; + public List cols = new ArrayList<>(); + } + + private static class AIOSDAG { + public List nodes = new ArrayList<>(); + public List schemas = new ArrayList<>(); + } + + private static int parseType(String type) { + switch (type.toLowerCase()) { + case "smallint": + case "int16": + return Types.SMALLINT; + case "int32": + case "i32": + case "int": + return Types.INTEGER; + case "int64": + case "bigint": + return Types.BIGINT; + case "float": + return Types.FLOAT; + case "double": + return Types.DOUBLE; + case "bool": + case "boolean": + return Types.BOOLEAN; + case "string": + return Types.VARCHAR; + case "timestamp": + return Types.TIMESTAMP; + case "date": + return Types.DATE; + default: + throw new RuntimeException("Unknown type: " + type); + } + } + + private static DAGNode buildAIOSDAG(Map sqls, Map> dag) { + Queue queue = new LinkedList<>(); + Map> childrenMap = new HashMap<>(); + Map degreeMap = new HashMap<>(); + Map nodeMap = new HashMap<>(); + for (String uuid: sqls.keySet()) { + Map parents = dag.get(uuid); + int degree = 0; + if (parents != null) { + for (String parent : parents.values()) { + if (dag.get(parent) != null) { 
+ degree += 1; + if (childrenMap.get(parent) == null) { + childrenMap.put(parent, new ArrayList<>()); + } + childrenMap.get(parent).add(uuid); + } + } + } + degreeMap.put(uuid, degree); + if (degree == 0) { + queue.offer(uuid); + } + } + + ArrayList targets = new ArrayList<>(); + while (!queue.isEmpty()) { + String uuid = queue.poll(); + String sql = sqls.get(uuid); + if (sql == null) { + continue; + } + + DAGNode node = new DAGNode(uuid, sql, new ArrayList()); + Map parents = dag.get(uuid); + for (Map.Entry parent : parents.entrySet()) { + DAGNode producer = nodeMap.get(parent.getValue()); + if (producer != null) { + node.producers.add(new DAGNode(parent.getKey(), producer.sql, producer.producers)); + } + } + nodeMap.put(uuid, node); + List children = childrenMap.get(uuid); + if (children == null || children.size() == 0) { + targets.add(node); + } else { + for (String child : children) { + degreeMap.put(child, degreeMap.get(child) - 1); + if (degreeMap.get(child) == 0) { + queue.offer(child); + } + } + } + } + + if (targets.size() == 0) { + throw new RuntimeException("Invalid DAG: target node not found"); + } else if (targets.size() > 1) { + throw new RuntimeException("Invalid DAG: target node is not unique"); + } + return targets.get(0); + } + + public static DAGNode parseAIOSDAG(String json) throws SQLException { + Gson gson = new Gson(); + AIOSDAG graph = gson.fromJson(json, AIOSDAG.class); + Map sqls = new HashMap<>(); + Map> dag = new HashMap<>(); + + for (AIOSDAGNode node : graph.nodes) { + if (sqls.get(node.uuid) != null) { + throw new RuntimeException("Duplicate 'uuid': " + node.uuid); + } + if (node.parents.size() != node.inputTables.size()) { + throw new RuntimeException("Size of 'parents' and 'inputTables' mismatch: " + node.uuid); + } + Map parents = new HashMap(); + for (int i = 0; i < node.parents.size(); i++) { + String table = node.inputTables.get(i); + if (parents.get(table) != null) { + throw new RuntimeException("Ambiguous name '" + table + "': 
" + node.uuid); + } + parents.put(table, node.parents.get(i)); + } + sqls.put(node.uuid, node.script); + dag.put(node.uuid, parents); + } + return buildAIOSDAG(sqls, dag); + } + + public static Map> parseAIOSTableSchema(String json, String usedDB) { + Gson gson = new Gson(); + AIOSDAG graph = gson.fromJson(json, AIOSDAG.class); + Map sqls = new HashMap<>(); + for (AIOSDAGNode node : graph.nodes) { + sqls.put(node.uuid, node.script); + } + + Map schemaMap = new HashMap<>(); + for (AIOSDAGSchema schema : graph.schemas) { + List columns = new ArrayList<>(); + for (AIOSDAGColumn column : schema.cols) { + try { + columns.add(new Column(column.name, parseType(column.type))); + } catch (Exception e) { + throw new RuntimeException("Unknown SQL type: " + column.type); + } + } + schemaMap.put(schema.prn, new Schema(columns)); + } + + Map tableSchema0 = new HashMap<>(); + for (AIOSDAGNode node : graph.nodes) { + for (int i = 0; i < node.parents.size(); i++) { + String table = node.inputTables.get(i); + if (sqls.get(node.parents.get(i)) == null) { + String prn = node.tableNameMap.get(table); + if (prn == null) { + throw new RuntimeException("Table not found in 'tableNameMap': " + + node.uuid + " " + table); + } + Schema schema = schemaMap.get(prn); + if (schema == null) { + throw new RuntimeException("Schema not found: " + prn); + } + if (tableSchema0.get(table) != null) { + if (tableSchema0.get(table) != schema) { + throw new RuntimeException("Table name conflict: " + table); + } + } + tableSchema0.put(table, schema); + } + } + } + + Map> tableSchema = new HashMap<>(); + tableSchema.put(usedDB, tableSchema0); + return tableSchema; + } +} \ No newline at end of file diff --git a/java/openmldb-jdbc/src/test/data/aiosdagsql/error1.json b/java/openmldb-jdbc/src/test/data/aiosdagsql/error1.json new file mode 100644 index 00000000000..906efea184f --- /dev/null +++ b/java/openmldb-jdbc/src/test/data/aiosdagsql/error1.json @@ -0,0 +1,108 @@ +{ + "nodes": [ + { + "id": -1, + "uuid": 
"8a41c2a7-5259-4dbd-9423-66f9d24f0194", + "type": "FeatureCompute", + "script": "select t3.*, csv(regression_label(t3.job)) from t1 last join t3 on t1.id \u003d t3.id", + "isDebug": false, + "isCurrent": false, + "parents": [ + "15810afc-b62f-4165-a027-a198f7e5a375", + "f84bb5fe-b247-4b43-8ae0-9c865c80052e" + ], + "inputTables": [ + "t1", + "t3" + ], + "tableNameMap": { + "t1": "modelIDE/train-QueryExec-1715152021-021413.table", + "t3": "modelIDE/train-QueryExec-1715152182-85b06d.table" + }, + "outputTables": [], + "instanceType": null, + "tables": {}, + "loader": null, + "originConfig": null, + "enablePrn": true + } + ], + "schemas": [ + { + "uuid": null, + "prn": "modelIDE/train-QueryExec-1715152021-021413.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "y", + "type": "Int" + }, + { + "name": "f1_bool", + "type": "Boolean" + }, + { + "name": "f2_sint", + "type": "SmallInt" + }, + { + "name": "f3_int", + "type": "Int" + }, + { + "name": "f4_bint", + "type": "BigInt" + }, + { + "name": "f5_float", + "type": "Float" + }, + { + "name": "f6_double", + "type": "Double" + }, + { + "name": "f7_date", + "type": "Date" + }, + { + "name": "f8_ts", + "type": "Timestamp" + }, + { + "name": "f9_str", + "type": "String" + } + ], + "isOutput": null + }, + { + "uuid": null, + "prn": "modelIDE/train-QueryExec-1715152182-85b06d.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "age", + "type": "Int" + }, + { + "name": "job", + "type": "String" + }, + { + "name": "marital", + "type": "String" + }, + { + "name": "education", + "type": "String" + } + ], + "isOutput": null + } + ] +} diff --git a/java/openmldb-jdbc/src/test/data/aiosdagsql/error1.sql b/java/openmldb-jdbc/src/test/data/aiosdagsql/error1.sql new file mode 100644 index 00000000000..e8c5e8dc3f2 --- /dev/null +++ b/java/openmldb-jdbc/src/test/data/aiosdagsql/error1.sql @@ -0,0 +1 @@ +select t3.*, csv(regression_label(t3.job)) from t1 last join t3 on t1.id = t3.id \ No newline at 
end of file diff --git a/java/openmldb-jdbc/src/test/data/aiosdagsql/input1.json b/java/openmldb-jdbc/src/test/data/aiosdagsql/input1.json new file mode 100644 index 00000000000..be05df260cc --- /dev/null +++ b/java/openmldb-jdbc/src/test/data/aiosdagsql/input1.json @@ -0,0 +1,365 @@ +{ + "nodes": [{ + "id": -1, + "uuid": "f078f46a-5b9b-4060-8c02-3b3f6626cb23", + "type": "FeatureCompute", + "script": "select t2.id,t2.instance from t1 last join t2 on t1.id \u003d t2.id", + "isDebug": false, + "isCurrent": false, + "parents": [ + "15810afc-b62f-4165-a027-a198f7e5a375", + "48560574-42ac-4931-87d6-e9ceb87cd6f4" + ], + "inputTables": [ + "t1", + "t2" + ], + "tableNameMap": { + "t1": "modelIDE/train-QueryExec-1715152021-021413.table", + "t2": "modelIDE/train-QueryExec-1715152021-8fd473.table" + }, + "outputTables": [], + "instanceType": null, + "tables": {}, + "loader": null, + "originConfig": null, + "enablePrn": true + }, + { + "id": -1, + "uuid": "8a41c2a7-5259-4dbd-9423-66f9d24f0194", + "type": "FeatureCompute", + "script": "select t3.* from t1 last join t3 on t1.id \u003d t3.id", + "isDebug": false, + "isCurrent": false, + "parents": [ + "15810afc-b62f-4165-a027-a198f7e5a375", + "f84bb5fe-b247-4b43-8ae0-9c865c80052e" + ], + "inputTables": [ + "t1", + "t3" + ], + "tableNameMap": { + "t1": "modelIDE/train-QueryExec-1715152021-021413.table", + "t3": "modelIDE/train-QueryExec-1715152182-85b06d.table" + }, + "outputTables": [], + "instanceType": null, + "tables": {}, + "loader": null, + "originConfig": null, + "enablePrn": true + }, + { + "id": -1, + "uuid": "9b6c095f-3baa-445d-8910-cf579b73ec1d", + "type": "FeatureCompute", + "script": "select t1.*,t2.age,t2.job,t2.marital from t1 last join t2 on t1.id \u003d t2.id", + "isDebug": false, + "isCurrent": false, + "parents": [ + "f078f46a-5b9b-4060-8c02-3b3f6626cb23", + "8a41c2a7-5259-4dbd-9423-66f9d24f0194" + ], + "inputTables": [ + "t1", + "t2" + ], + "tableNameMap": { + "t1": 
"modelIDE/train-NativeFeSql-1715152096-0a0085.table", + "t2": "modelIDE/train-NativeFeSql-1715152242-537c22.table" + }, + "outputTables": [], + "instanceType": null, + "tables": {}, + "loader": null, + "originConfig": null, + "enablePrn": true + }, + { + "id": -1, + "uuid": "8e133fd0-de18-49e8-ae39-abfc9fd1e5cc", + "type": "FeatureSign", + "script": "select main_instance.instance from main_table last join main_instance on main_table.id \u003d main_instance.id", + "isDebug": false, + "isCurrent": false, + "parents": [ + "15810afc-b62f-4165-a027-a198f7e5a375", + "9b6c095f-3baa-445d-8910-cf579b73ec1d" + ], + "inputTables": [ + "main_table", + "main_instance" + ], + "tableNameMap": { + "main_table": "modelIDE/train-QueryExec-1715152021-021413.table", + "main_instance": "modelIDE/train-NativeFeSql-1715152659-f5c401.table" + }, + "outputTables": [], + "instanceType": null, + "tables": {}, + "loader": null, + "originConfig": null, + "enablePrn": true + } + ], + "schemas": [ + { + "uuid": null, + "prn": "modelIDE/train-QueryExec-1715152021-021413.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "y", + "type": "Int" + }, + { + "name": "f1_bool", + "type": "Boolean" + }, + { + "name": "f2_sint", + "type": "SmallInt" + }, + { + "name": "f3_int", + "type": "Int" + }, + { + "name": "f4_bint", + "type": "BigInt" + }, + { + "name": "f5_float", + "type": "Float" + }, + { + "name": "f6_double", + "type": "Double" + }, + { + "name": "f7_date", + "type": "Date" + }, + { + "name": "f8_ts", + "type": "Timestamp" + }, + { + "name": "f9_str", + "type": "String" + } + ], + "isOutput": null + }, + { + "uuid": null, + "prn": "modelIDE/train-QueryExec-1715152021-8fd473.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "instance", + "type": "String" + } + ], + "isOutput": null + }, + { + "uuid": null, + "prn": "modelIDE/train-QueryExec-1715152021-021413.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "y", + "type": "Int" + 
}, + { + "name": "f1_bool", + "type": "Boolean" + }, + { + "name": "f2_sint", + "type": "SmallInt" + }, + { + "name": "f3_int", + "type": "Int" + }, + { + "name": "f4_bint", + "type": "BigInt" + }, + { + "name": "f5_float", + "type": "Float" + }, + { + "name": "f6_double", + "type": "Double" + }, + { + "name": "f7_date", + "type": "Date" + }, + { + "name": "f8_ts", + "type": "Timestamp" + }, + { + "name": "f9_str", + "type": "String" + } + ], + "isOutput": null + }, + { + "uuid": null, + "prn": "modelIDE/train-QueryExec-1715152182-85b06d.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "age", + "type": "Int" + }, + { + "name": "job", + "type": "String" + }, + { + "name": "marital", + "type": "String" + }, + { + "name": "education", + "type": "String" + } + ], + "isOutput": null + }, + { + "uuid": null, + "prn": "modelIDE/train-NativeFeSql-1715152096-0a0085.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "instance", + "type": "String" + } + ], + "isOutput": null + }, + { + "uuid": null, + "prn": "modelIDE/train-NativeFeSql-1715152242-537c22.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "age", + "type": "Int" + }, + { + "name": "job", + "type": "String" + }, + { + "name": "marital", + "type": "String" + }, + { + "name": "education", + "type": "String" + } + ], + "isOutput": null + }, + { + "uuid": null, + "prn": "modelIDE/train-QueryExec-1715152021-021413.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "y", + "type": "Int" + }, + { + "name": "f1_bool", + "type": "Boolean" + }, + { + "name": "f2_sint", + "type": "SmallInt" + }, + { + "name": "f3_int", + "type": "Int" + }, + { + "name": "f4_bint", + "type": "BigInt" + }, + { + "name": "f5_float", + "type": "Float" + }, + { + "name": "f6_double", + "type": "Double" + }, + { + "name": "f7_date", + "type": "Date" + }, + { + "name": "f8_ts", + "type": "Timestamp" + }, + { + "name": "f9_str", + "type": "String" + } + ], + 
"isOutput": null + }, + { + "uuid": null, + "prn": "modelIDE/train-NativeFeSql-1715152659-f5c401.table", + "cols": [{ + "name": "id", + "type": "Int" + }, + { + "name": "instance", + "type": "String" + }, + { + "name": "age", + "type": "Int" + }, + { + "name": "job", + "type": "String" + }, + { + "name": "marital", + "type": "String" + } + ], + "isOutput": null + } + ] +} diff --git a/java/openmldb-jdbc/src/test/data/aiosdagsql/output1.sql b/java/openmldb-jdbc/src/test/data/aiosdagsql/output1.sql new file mode 100644 index 00000000000..93ca8997f69 --- /dev/null +++ b/java/openmldb-jdbc/src/test/data/aiosdagsql/output1.sql @@ -0,0 +1,10 @@ +WITH main_instance as ( +WITH t1 as ( +select t2.id,t2.instance from t1 last join t2 on t1.id = t2.id +), +t2 as ( +select t3.* from t1 last join t3 on t1.id = t3.id +) +select t1.*,t2.age,t2.job,t2.marital from t1 last join t2 on t1.id = t2.id +) +select main_instance.instance from main_table last join main_instance on main_table.id = main_instance.id \ No newline at end of file diff --git a/java/openmldb-jdbc/src/test/java/com/_4paradigm/openmldb/jdbc/SQLRouterSmokeTest.java b/java/openmldb-jdbc/src/test/java/com/_4paradigm/openmldb/jdbc/SQLRouterSmokeTest.java index 60a0ef744f5..3a7f4b82237 100644 --- a/java/openmldb-jdbc/src/test/java/com/_4paradigm/openmldb/jdbc/SQLRouterSmokeTest.java +++ b/java/openmldb-jdbc/src/test/java/com/_4paradigm/openmldb/jdbc/SQLRouterSmokeTest.java @@ -26,12 +26,18 @@ import com._4paradigm.openmldb.sdk.SdkOption; import com._4paradigm.openmldb.sdk.SqlExecutor; import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor; +import com._4paradigm.openmldb.sdk.utils.AIOSUtil; import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import org.testng.collections.Maps; +import com.google.gson.Gson; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.sql.PreparedStatement; import 
java.sql.ResultSetMetaData; import java.sql.SQLException; @@ -927,4 +933,58 @@ public void testSQLToDag(SqlExecutor router) throws SQLException { "FROM\n" + " t2\n"); } + + private void testMergeDAGSQLCase(String input, String output, String error) { + Exception exception = null; + try { + DAGNode dag = AIOSUtil.parseAIOSDAG(input); + Map> tableSchema = AIOSUtil.parseAIOSTableSchema(input, "usedDB"); + String merged = SqlClusterExecutor.mergeDAGSQL(dag); + System.out.println(merged); + Assert.assertEquals(merged, output); + List errors = SqlClusterExecutor.validateSQLInRequest(merged, "usedDB", tableSchema); + if (!errors.isEmpty()) { + throw new SQLException("merged sql is invalid: " + errors + + "\n, merged sql: " + merged + "\n, table schema: " + tableSchema); + } + } catch (Exception e) { + e.printStackTrace(); + exception = e; + } + if (error == null) { + Assert.assertTrue(exception == null); + } else { + Assert.assertTrue(exception.toString().contains(error)); + } + } + + @Test + public void testMergeDAGSQL() throws IOException { + System.out.println("user.dir: " + System.getProperty("user.dir")); + ArrayList inputs = new ArrayList<>(); + ArrayList outputs = new ArrayList<>(); + inputs.add(Paths.get("src/test/data/aiosdagsql/input1.json")); + outputs.add(Paths.get("src/test/data/aiosdagsql/output1.sql")); + for (int i = 0; i < inputs.size(); ++i) { + String input = new String(Files.readAllBytes(inputs.get(i))); + String output = new String(Files.readAllBytes(outputs.get(i))); + testMergeDAGSQLCase(input, output, null); + } + } + + @Test + public void testMergeDAGSQLError() throws IOException { + System.out.println("user.dir: " + System.getProperty("user.dir")); + ArrayList inputs = new ArrayList<>(); + ArrayList outputs = new ArrayList<>(); + inputs.add(Paths.get("src/test/data/aiosdagsql/error1.json")); + outputs.add(Paths.get("src/test/data/aiosdagsql/error1.sql")); + for (int i = 0; i < inputs.size(); ++i) { + String input = new 
String(Files.readAllBytes(inputs.get(i))); + String output = new String(Files.readAllBytes(outputs.get(i))); + testMergeDAGSQLCase(input, output, "Fail to resolve expression"); + } + } + } + From 63d3a170efb1eb48971652a0ee7cb3d15edbdee4 Mon Sep 17 00:00:00 2001 From: wyl4pd <164864310+wyl4pd@users.noreply.github.com> Date: Tue, 14 May 2024 16:15:08 +0800 Subject: [PATCH 09/23] fix: gcformat space and continuous sign (#3921) * fix: gcformat space * fix: gcformat continuous sign use hash * fix: delete incorrect comments --- cases/query/feature_signature_query.yaml | 80 +++++++++++-------- .../udf/default_defs/feature_signature_def.cc | 17 +++- 2 files changed, 58 insertions(+), 39 deletions(-) diff --git a/cases/query/feature_signature_query.yaml b/cases/query/feature_signature_query.yaml index 1cfbd9b229a..c0763d320e7 100644 --- a/cases/query/feature_signature_query.yaml +++ b/cases/query/feature_signature_query.yaml @@ -43,7 +43,7 @@ cases: mode: procedure-unsupport db: db1 sql: | - select gcformat( + select concat("#", gcformat( discrete(3, -1), discrete(3, 0), discrete(3, int("null")), @@ -57,31 +57,31 @@ cases: discrete(-1, 5), discrete(-2, 5), discrete(-3, 5), - discrete(-4, 5)) as instance, + discrete(-4, 5))) as instance; expect: schema: instance:string data: | - | 4:628 5:491882390849628 6:0 7:4 8:1 9:3 10:1 11:1 12:0 13:0 14:4 + # | 4:628 5:491882390849628 6:0 7:4 8:1 9:3 10:1 11:1 12:0 13:0 14:4 - id: 2 desc: feature signature select GCFormat no label mode: procedure-unsupport db: db1 sql: | - select gcformat( + select concat("#", gcformat( discrete(hash64("x"), 1), continuous(pow(10, 30)), continuous(-pow(10, 1000)), - continuous(abs(sqrt(-1)))) as instance; + continuous(abs(sqrt(-1))))) as instance; expect: schema: instance:string data: | - | 1:0 2:0:1000000000000000019884624838656.000000 3:0:-inf 4:0:nan + # | 1:0 2:3353244675891348105:1000000000000000019884624838656.000000 3:7262150054277104024:-inf 4:3255232038643208583:nan - id: 3 desc: feature 
signature GCFormat null mode: procedure-unsupport db: db1 sql: | - select gcformat( + select concat("#", gcformat( regression_label(2), regression_label(int("null")), continuous(int("null")), @@ -98,31 +98,31 @@ cases: discrete(3, -100), discrete(3), continuous(0.0), - continuous(int("null"))) as instance; + continuous(int("null")))) as instance; expect: schema: instance:string data: | - | 3:0:-1 4:0:2681491882390849628 5:28 8:2681491882390849628 9:0:-1 10:28 13:2681491882390849628 14:0:0.000000 + # | 3:7262150054277104024:-1 4:3255232038643208583:2681491882390849628 5:28 8:2681491882390849628 9:-7745589761753622095:-1 10:28 13:2681491882390849628 14:398281081943027035:0.000000 - id: 4 desc: feature signature GCFormat no feature mode: procedure-unsupport db: db1 sql: | - select gcformat(binary_label(false)); + select concat(gcformat(binary_label(false)), "#") as instance; expect: - schema: gcformat(binary_label(false)):string + schema: instance:string data: | - 0| + 0 | # - id: 5 desc: feature signature GCFormat nothing mode: procedure-unsupport db: db1 sql: | - select gcformat(); + select concat(concat("#", gcformat()), "#") as instance; expect: - schema: gcformat():string + schema: instance:string data: | - | + # | # - id: 6 desc: feature signature CSV no label mode: procedure-unsupport @@ -136,7 +136,7 @@ cases: expect: columns: [instance:string] rows: - - [",,,628"] + - [ ",,,628" ] - id: 7 desc: feature signature CSV null mode: procedure-unsupport @@ -163,7 +163,7 @@ cases: expect: columns: [ "instance:string "] rows: - - ["2,,,,-1,2681491882390849628,28,,,2681491882390849628,-1,28,,,2681491882390849628,0.000000,"] + - [ "2,,,,-1,2681491882390849628,28,,,2681491882390849628,-1,28,,,2681491882390849628,0.000000," ] - id: 8 desc: feature signature CSV no feature mode: procedure-unsupport @@ -263,7 +263,7 @@ cases: expect: schema: instance:string data: | - 1| 1:0:0 2:0:1 3:0 + 1 | 1:5925585971146611297:0 2:3353244675891348105:1 3:0 - id: 15 desc: feature 
signature select GCFormat from mode: request-unsupport @@ -289,11 +289,11 @@ cases: schema: instance:string order: instance data: | - 1| 1:0:0 2:0:1 3:0 - 2| 1:0:0 2:0:2 3:0 - 3| 1:0:1 2:0:3 3:0 - 4| 1:0:1 2:0:4 3:0 - 5| 1:0:2 2:0:5 3:0 + 1 | 1:5925585971146611297:0 2:3353244675891348105:1 3:0 + 2 | 1:5925585971146611297:0 2:3353244675891348105:2 3:0 + 3 | 1:5925585971146611297:1 2:3353244675891348105:3 3:0 + 4 | 1:5925585971146611297:1 2:3353244675891348105:4 3:0 + 5 | 1:5925585971146611297:2 2:3353244675891348105:5 3:0 - id: 16 desc: feature signature select CSV from mode: request-unsupport @@ -360,7 +360,7 @@ cases: mode: request-unsupport db: db1 sql: | - SELECT gcformat(regression_label(col1)) as col1, + SELECT gcformat(regression_label(col1), discrete(col1, 1)) as col1, csv(regression_label(col1)) as col2, libsvm(regression_label(col1)) as col3 FROM t1; @@ -375,14 +375,14 @@ cases: 1, 4, 55, 4.4, 44.4, 2, 4444 2, 5, 55, 5.5, 55.5, 3, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa expect: - schema: col1:string, col2:string, col3:string - order: col1 - data: | - 1|, 1, 1 - 2|, 2, 2 - 3|, 3, 3 - 4|, 4, 4 - 5|, 5, 5 + columns: [ "col1:string", "col2:string", "col3:string" ] + order: "col1" + rows: + - [ "1 | 1:0", "1", "1" ] + - [ "2 | 1:0", "2", "2" ] + - [ "3 | 1:0", "3", "3" ] + - [ "4 | 1:0", "4", "4" ] + - [ "5 | 1:0", "5", "5" ] - id: 19 desc: feature signature select from join mode: request-unsupport @@ -471,15 +471,25 @@ cases: mode: procedure-unsupport db: db1 sql: | - select gcformat( + select concat("#", gcformat( regression_label(2), continuous(1), continuous(int("notint")), continuous(0), continuous(0.0), discrete(3), - regression_label(int("notint"))) as instance; + regression_label(int("notint")))) as instance; expect: schema: instance:string data: | - | 1:0:1 3:0:0 4:0:0.000000 5:2681491882390849628 + # | 1:5925585971146611297:1 3:7262150054277104024:0 4:3255232038643208583:0.000000 5:2681491882390849628 + - id: 23 + desc: 
hash64 + mode: procedure-unsupport + db: db1 + sql: | + select hash64(3) as col1, hash64(bigint(3)) as col2; + expect: + schema: col1:int64, col2:int64 + data: | + 2681491882390849628, 7262150054277104024 diff --git a/hybridse/src/udf/default_defs/feature_signature_def.cc b/hybridse/src/udf/default_defs/feature_signature_def.cc index 3f9586c7f61..b407d513bb4 100644 --- a/hybridse/src/udf/default_defs/feature_signature_def.cc +++ b/hybridse/src/udf/default_defs/feature_signature_def.cc @@ -204,14 +204,23 @@ struct GCFormat { switch (feature_signature) { case kFeatureSignatureContinuous: { if (!is_null) { - instance_feature += " " + std::to_string(slot_number) + ":0:" + format_continuous(input); + if (!instance_feature.empty()) { + instance_feature += " "; + } + int64_t hash = FarmFingerprint(CCallDataTypeTrait::to_bytes_ref(&slot_number)); + instance_feature += std::to_string(slot_number) + ":"; + instance_feature += format_discrete(hash); + instance_feature += ":" + format_continuous(input); } ++slot_number; break; } case kFeatureSignatureDiscrete: { if (!is_null) { - instance_feature += " " + std::to_string(slot_number) + ":" + format_discrete(input); + if (!instance_feature.empty()) { + instance_feature += " "; + } + instance_feature += std::to_string(slot_number) + ":" + format_discrete(input); } ++slot_number; break; @@ -249,7 +258,7 @@ struct GCFormat { } std::string Output() { - return instance_label + "|" + instance_feature; + return instance_label + " | " + instance_feature; } size_t slot_number = 1; @@ -482,7 +491,7 @@ void DefaultUdfLibrary::InitFeatureSignature() { Example: @code{.sql} select gcformat(multiclass_label(6), continuous(1.5), category(3)); - -- output 6| 1:0:1.500000 2:2681491882390849628 + -- output 6 | 1:0:1.500000 2:2681491882390849628 @endcode @since 0.9.0 From ba817e44c5ca5e199befb080f2aa403e1c9db66a Mon Sep 17 00:00:00 2001 From: tobe Date: Thu, 16 May 2024 14:23:57 +0800 Subject: [PATCH 10/23] feat: merge 090 features to main (#3929) 
* Set s3 and aws dependencies ad provided (#3897) * feat: execlude zookeeper for curator (#3899) * Execlude zookeeper when using curator * Fix local build java --- java/openmldb-batch/pom.xml | 13 ++++++++++- java/openmldb-common/pom.xml | 6 +++++ java/openmldb-taskmanager/pom.xml | 9 +++++--- .../taskmanager/server/JobResultSaver.java | 22 +++++++++---------- .../taskmanager/zk/RecoverableZooKeeper.java | 2 +- 5 files changed, 36 insertions(+), 16 deletions(-) diff --git a/java/openmldb-batch/pom.xml b/java/openmldb-batch/pom.xml index b69f58abc2b..5fcbca9a8f1 100644 --- a/java/openmldb-batch/pom.xml +++ b/java/openmldb-batch/pom.xml @@ -167,7 +167,11 @@ - + + org.apache.zookeeper + zookeeper + 3.4.14 + org.apache.curator curator-framework @@ -182,6 +186,12 @@ org.apache.curator curator-recipes 4.2.0 + + + org.apache.zookeeper + zookeeper + + @@ -241,6 +251,7 @@ org.apache.hadoop hadoop-aws ${hadoop.version} + provided diff --git a/java/openmldb-common/pom.xml b/java/openmldb-common/pom.xml index d19b9cac681..afa86a5a0dc 100644 --- a/java/openmldb-common/pom.xml +++ b/java/openmldb-common/pom.xml @@ -40,6 +40,12 @@ org.apache.curator curator-recipes 4.2.0 + + + org.apache.zookeeper + zookeeper + + org.testng diff --git a/java/openmldb-taskmanager/pom.xml b/java/openmldb-taskmanager/pom.xml index 34039fb642a..1b0fe69928e 100644 --- a/java/openmldb-taskmanager/pom.xml +++ b/java/openmldb-taskmanager/pom.xml @@ -134,6 +134,12 @@ org.apache.curator curator-recipes 4.2.0 + + + org.apache.zookeeper + zookeeper + + org.projectlombok @@ -142,9 +148,6 @@ provided - - - io.fabric8 diff --git a/java/openmldb-taskmanager/src/main/java/com/_4paradigm/openmldb/taskmanager/server/JobResultSaver.java b/java/openmldb-taskmanager/src/main/java/com/_4paradigm/openmldb/taskmanager/server/JobResultSaver.java index 570bc035603..0e9825d0423 100644 --- a/java/openmldb-taskmanager/src/main/java/com/_4paradigm/openmldb/taskmanager/server/JobResultSaver.java +++ 
b/java/openmldb-taskmanager/src/main/java/com/_4paradigm/openmldb/taskmanager/server/JobResultSaver.java @@ -53,7 +53,7 @@ */ @Slf4j public class JobResultSaver { - private static final Log log = LogFactory.getLog(JobResultSaver.class); + private static final Log logger = LogFactory.getLog(JobResultSaver.class); // false: unused, true: using // 0: unused, 1: saving, 2: finished but still in use @@ -92,8 +92,8 @@ public String genUniqueFileName() { public boolean saveFile(int resultId, String jsonData) { // No need to wait, cuz id status must have been changed by genResultId before. // It's a check. - if (log.isDebugEnabled()) { - log.debug("save result " + resultId + ", data " + jsonData); + if (logger.isDebugEnabled()) { + logger.debug("save result " + resultId + ", data " + jsonData); } int status = idStatus.get(resultId); if (status != 1) { @@ -105,7 +105,7 @@ public boolean saveFile(int resultId, String jsonData) { idStatus.set(resultId, 2); idStatus.notifyAll(); } - log.info("saved all result of result " + resultId); + logger.info("saved all result of result " + resultId); return true; } // save to /tmp_result// @@ -114,7 +114,7 @@ public boolean saveFile(int resultId, String jsonData) { File saveP = new File(savePath); if (!saveP.exists()) { boolean res = saveP.mkdirs(); - log.info("create save path " + savePath + ", status " + res); + logger.info("create save path " + savePath + ", status " + res); } } String fileFullPath = String.format("%s/%s", savePath, genUniqueFileName()); @@ -125,7 +125,7 @@ public boolean saveFile(int resultId, String jsonData) { + fileFullPath); } } catch (IOException e) { - log.error("create file failed, path " + fileFullPath, e); + logger.error("create file failed, path " + fileFullPath, e); return false; } @@ -135,7 +135,7 @@ public boolean saveFile(int resultId, String jsonData) { } catch (IOException e) { // Write failed, we'll lost a part of result, but it's ok for show sync job // output. 
So we just log it, and response the http request. - log.error("write result to file failed, path " + fileFullPath, e); + logger.error("write result to file failed, path " + fileFullPath, e); return false; } return true; @@ -151,7 +151,7 @@ public String readResult(int resultId, long timeoutMs) throws InterruptedExcepti } } if (idStatus.get(resultId) != 2) { - log.warn("read result timeout, result saving may be still running, try read anyway, id " + resultId); + logger.warn("read result timeout, result saving may be still running, try read anyway, id " + resultId); } String output = ""; // all finished, read csv from savePath @@ -163,7 +163,7 @@ public String readResult(int resultId, long timeoutMs) throws InterruptedExcepti output = printFilesTostr(savePath); FileUtils.forceDelete(saveP); } else { - log.info("empty result for " + resultId + ", show empty string"); + logger.info("empty result for " + resultId + ", show empty string"); } // reset id synchronized (idStatus) { @@ -189,7 +189,7 @@ public String printFilesTostr(String fileDir) { } return stringWriter.toString(); } catch (Exception e) { - log.warn("read result met exception when read " + fileDir + ", " + e.getMessage()); + logger.warn("read result met exception when read " + fileDir + ", " + e.getMessage()); e.printStackTrace(); return "read met exception, check the taskmanager log"; } @@ -219,7 +219,7 @@ private void printFile(String file, StringWriter stringWriter, boolean printHead csvPrinter.printRecord(iter.next()); } } catch (Exception e) { - log.warn("error when print result file " + file + ", ignore it"); + logger.warn("error when print result file " + file + ", ignore it"); e.printStackTrace(); } } diff --git a/java/openmldb-taskmanager/src/main/java/com/_4paradigm/openmldb/taskmanager/zk/RecoverableZooKeeper.java b/java/openmldb-taskmanager/src/main/java/com/_4paradigm/openmldb/taskmanager/zk/RecoverableZooKeeper.java index 9ff2b9349b4..10bc226ef50 100644 --- 
a/java/openmldb-taskmanager/src/main/java/com/_4paradigm/openmldb/taskmanager/zk/RecoverableZooKeeper.java +++ b/java/openmldb-taskmanager/src/main/java/com/_4paradigm/openmldb/taskmanager/zk/RecoverableZooKeeper.java @@ -62,7 +62,7 @@ public class RecoverableZooKeeper { private final String quorumServers; private final int maxMultiSize; // unused now - @edu.umd.cs.findbugs.annotations.SuppressWarnings(value = "DE_MIGHT_IGNORE", justification = "None. Its always been this way.") + //@edu.umd.cs.findbugs.annotations.SuppressWarnings(value = "DE_MIGHT_IGNORE", justification = "None. Its always been this way.") public RecoverableZooKeeper(String quorumServers, int sessionTimeout, Watcher watcher) throws IOException { // TODO: Add support for zk 'chroot'; we don't add it to the quorumServers // String as we should. From 5bbf9e34fa39b45eb695e40dcce6a17001339b0d Mon Sep 17 00:00:00 2001 From: tobe Date: Fri, 17 May 2024 17:43:25 +0800 Subject: [PATCH 11/23] Run script to update post release version (#3931) --- CMakeLists.txt | 4 ++-- java/hybridse-native/pom.xml | 2 +- java/hybridse-proto/pom.xml | 2 +- java/hybridse-sdk/pom.xml | 2 +- java/openmldb-batch/pom.xml | 2 +- java/openmldb-batchjob/pom.xml | 2 +- java/openmldb-common/pom.xml | 2 +- java/openmldb-jdbc/pom.xml | 2 +- java/openmldb-native/pom.xml | 2 +- java/openmldb-spark-connector/pom.xml | 2 +- java/openmldb-synctool/pom.xml | 2 +- java/openmldb-taskmanager/pom.xml | 2 +- java/pom.xml | 4 ++-- python/openmldb_sdk/setup.py | 2 +- python/openmldb_tool/setup.py | 2 +- 15 files changed, 17 insertions(+), 17 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1ac375c1d9e..7fc334f8566 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,8 +40,8 @@ endif() message (STATUS "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}") message (STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") set(OPENMLDB_VERSION_MAJOR 0) -set(OPENMLDB_VERSION_MINOR 8) -set(OPENMLDB_VERSION_BUG 6) +set(OPENMLDB_VERSION_MINOR 9) 
+set(OPENMLDB_VERSION_BUG 1) function(get_commitid CODE_DIR COMMIT_ID) find_package(Git REQUIRED) diff --git a/java/hybridse-native/pom.xml b/java/hybridse-native/pom.xml index 632f3fe04a4..ba85e0169a0 100644 --- a/java/hybridse-native/pom.xml +++ b/java/hybridse-native/pom.xml @@ -5,7 +5,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT ../pom.xml 4.0.0 diff --git a/java/hybridse-proto/pom.xml b/java/hybridse-proto/pom.xml index e179740a4ec..4bd333cb322 100644 --- a/java/hybridse-proto/pom.xml +++ b/java/hybridse-proto/pom.xml @@ -4,7 +4,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT ../pom.xml 4.0.0 diff --git a/java/hybridse-sdk/pom.xml b/java/hybridse-sdk/pom.xml index 34d9e34e37e..ed8911fa572 100644 --- a/java/hybridse-sdk/pom.xml +++ b/java/hybridse-sdk/pom.xml @@ -6,7 +6,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT ../pom.xml 4.0.0 diff --git a/java/openmldb-batch/pom.xml b/java/openmldb-batch/pom.xml index 5fcbca9a8f1..8c0371e227a 100644 --- a/java/openmldb-batch/pom.xml +++ b/java/openmldb-batch/pom.xml @@ -7,7 +7,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT openmldb-batch diff --git a/java/openmldb-batchjob/pom.xml b/java/openmldb-batchjob/pom.xml index 101758c60dc..e449320d012 100644 --- a/java/openmldb-batchjob/pom.xml +++ b/java/openmldb-batchjob/pom.xml @@ -7,7 +7,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT openmldb-batchjob diff --git a/java/openmldb-common/pom.xml b/java/openmldb-common/pom.xml index afa86a5a0dc..6be5746496c 100644 --- a/java/openmldb-common/pom.xml +++ b/java/openmldb-common/pom.xml @@ -5,7 +5,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT 4.0.0 openmldb-common diff --git a/java/openmldb-jdbc/pom.xml b/java/openmldb-jdbc/pom.xml index dca8db69e79..f51395b9ac0 100644 --- a/java/openmldb-jdbc/pom.xml +++ b/java/openmldb-jdbc/pom.xml @@ -5,7 +5,7 
@@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT ../pom.xml 4.0.0 diff --git a/java/openmldb-native/pom.xml b/java/openmldb-native/pom.xml index 214a918f7b6..ed3c45fae8b 100644 --- a/java/openmldb-native/pom.xml +++ b/java/openmldb-native/pom.xml @@ -5,7 +5,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT ../pom.xml 4.0.0 diff --git a/java/openmldb-spark-connector/pom.xml b/java/openmldb-spark-connector/pom.xml index 574e8ddbf84..529618163e0 100644 --- a/java/openmldb-spark-connector/pom.xml +++ b/java/openmldb-spark-connector/pom.xml @@ -6,7 +6,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT openmldb-spark-connector diff --git a/java/openmldb-synctool/pom.xml b/java/openmldb-synctool/pom.xml index d752ce6d41b..bbdb1aa1fa8 100644 --- a/java/openmldb-synctool/pom.xml +++ b/java/openmldb-synctool/pom.xml @@ -6,7 +6,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT openmldb-synctool openmldb-synctool diff --git a/java/openmldb-taskmanager/pom.xml b/java/openmldb-taskmanager/pom.xml index 1b0fe69928e..6fee727ff3e 100644 --- a/java/openmldb-taskmanager/pom.xml +++ b/java/openmldb-taskmanager/pom.xml @@ -6,7 +6,7 @@ openmldb-parent com.4paradigm.openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT openmldb-taskmanager openmldb-taskmanager diff --git a/java/pom.xml b/java/pom.xml index 7435a8cdfee..999ae2b8bae 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -7,7 +7,7 @@ openmldb-parent pom openmldb - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT hybridse-sdk hybridse-native @@ -65,7 +65,7 @@ - 0.8.6-SNAPSHOT + 0.9.1-SNAPSHOT error 2.9.0 diff --git a/python/openmldb_sdk/setup.py b/python/openmldb_sdk/setup.py index fa92ff71911..c682cc0c49f 100644 --- a/python/openmldb_sdk/setup.py +++ b/python/openmldb_sdk/setup.py @@ -18,7 +18,7 @@ setup( name='openmldb', - version='0.8.6a0', + version='0.9.1a0', author='OpenMLDB Team', author_email=' ', url='https://github.com/4paradigm/OpenMLDB', 
diff --git a/python/openmldb_tool/setup.py b/python/openmldb_tool/setup.py index e36856f8d37..d43a21c1c70 100644 --- a/python/openmldb_tool/setup.py +++ b/python/openmldb_tool/setup.py @@ -18,7 +18,7 @@ setup( name="openmldb-tool", - version='0.8.6a0', + version='0.9.1a0', author="OpenMLDB Team", author_email=" ", url="https://github.com/4paradigm/OpenMLDB", From 21184d56251cd96088d787dfdb32527c84c78467 Mon Sep 17 00:00:00 2001 From: oh2024 <162292688+oh2024@users.noreply.github.com> Date: Mon, 20 May 2024 14:49:14 +0800 Subject: [PATCH 12/23] feat: crud users synchronously (#3928) * fix: make clients use auth by default * fix: let skip auth flag only affect verify * feat: tablets get user table remotely * fix: use FLAGS_system_table_replica_num for user table * feat: consistent user cruds * fix: pass instance of tablet and nameserver into auth lambda to allow locking * feat: best effort try to flush user data to all tablets * fix: lock scope * fix: stop user sync thread safely * fix: default values for user table columns --- src/auth/user_access_manager.cc | 13 ++--- src/auth/user_access_manager.h | 6 +- src/base/status.h | 5 +- src/client/ns_client.cc | 28 +++++++++ src/client/ns_client.h | 4 ++ src/client/tablet_client.cc | 12 ++++ src/client/tablet_client.h | 2 + src/cmd/openmldb.cc | 13 ++--- src/cmd/sql_cmd_test.cc | 3 - src/nameserver/name_server_impl.cc | 92 +++++++++++++++++++++++++----- src/nameserver/name_server_impl.h | 16 +++++- src/proto/name_server.proto | 15 +++++ src/proto/tablet.proto | 3 + src/sdk/mini_cluster.h | 83 +++------------------------ src/sdk/sql_cluster_router.cc | 64 +++++++++++---------- src/tablet/tablet_impl.cc | 14 ++++- src/tablet/tablet_impl.h | 8 ++- 17 files changed, 232 insertions(+), 149 deletions(-) diff --git a/src/auth/user_access_manager.cc b/src/auth/user_access_manager.cc index 32506d8cbcf..d668a7dc497 100644 --- a/src/auth/user_access_manager.cc +++ b/src/auth/user_access_manager.cc @@ -32,20 +32,17 @@ 
UserAccessManager::UserAccessManager(IteratorFactory iterator_factory) UserAccessManager::~UserAccessManager() { StopSyncTask(); } void UserAccessManager::StartSyncTask() { - sync_task_running_ = true; - sync_task_thread_ = std::thread([this] { - while (sync_task_running_) { + sync_task_thread_ = std::thread([this, fut = stop_promise_.get_future()] { + while (true) { SyncWithDB(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (fut.wait_for(std::chrono::minutes(15)) != std::future_status::timeout) return; } }); } void UserAccessManager::StopSyncTask() { - sync_task_running_ = false; - if (sync_task_thread_.joinable()) { - sync_task_thread_.join(); - } + stop_promise_.set_value(); + sync_task_thread_.join(); } void UserAccessManager::SyncWithDB() { diff --git a/src/auth/user_access_manager.h b/src/auth/user_access_manager.h index af2dc0c6791..996efc326c4 100644 --- a/src/auth/user_access_manager.h +++ b/src/auth/user_access_manager.h @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -38,14 +39,13 @@ class UserAccessManager { ~UserAccessManager(); bool IsAuthenticated(const std::string& host, const std::string& username, const std::string& password); + void SyncWithDB(); private: IteratorFactory user_table_iterator_factory_; RefreshableMap user_map_; - std::atomic sync_task_running_{false}; std::thread sync_task_thread_; - - void SyncWithDB(); + std::promise stop_promise_; void StartSyncTask(); void StopSyncTask(); }; diff --git a/src/base/status.h b/src/base/status.h index 8ac134b18bd..c7e5ec75198 100644 --- a/src/base/status.h +++ b/src/base/status.h @@ -183,7 +183,10 @@ enum ReturnCode { kSQLRunError = 1001, kRPCRunError = 1002, kServerConnError = 1003, - kRPCError = 1004 // brpc controller error + kRPCError = 1004, // brpc controller error + + // auth + kFlushPrivilegesFailed = 1100 // brpc controller error }; struct Status { diff --git a/src/client/ns_client.cc b/src/client/ns_client.cc index 
9a4baa549bc..cdeef07e521 100644 --- a/src/client/ns_client.cc +++ b/src/client/ns_client.cc @@ -19,6 +19,7 @@ #include #include "base/strings.h" +#include "ns_client.h" DECLARE_int32(request_timeout_ms); namespace openmldb { @@ -302,6 +303,33 @@ bool NsClient::CreateTable(const ::openmldb::nameserver::TableInfo& table_info, bool NsClient::DropTable(const std::string& name, std::string& msg) { return DropTable(GetDb(), name, msg); } +bool NsClient::PutUser(const std::string& host, const std::string& name, const std::string& password) { + ::openmldb::nameserver::PutUserRequest request; + request.set_host(host); + request.set_name(name); + request.set_password(password); + ::openmldb::nameserver::GeneralResponse response; + bool ok = client_.SendRequest(&::openmldb::nameserver::NameServer_Stub::PutUser, &request, &response, + FLAGS_request_timeout_ms, 1); + if (ok && response.code() == 0) { + return true; + } + return false; +} + +bool NsClient::DeleteUser(const std::string& host, const std::string& name) { + ::openmldb::nameserver::DeleteUserRequest request; + request.set_host(host); + request.set_name(name); + ::openmldb::nameserver::GeneralResponse response; + bool ok = client_.SendRequest(&::openmldb::nameserver::NameServer_Stub::DeleteUser, &request, &response, + FLAGS_request_timeout_ms, 1); + if (ok && response.code() == 0) { + return true; + } + return false; +} + bool NsClient::DropTable(const std::string& db, const std::string& name, std::string& msg) { ::openmldb::nameserver::DropTableRequest request; request.set_name(name); diff --git a/src/client/ns_client.h b/src/client/ns_client.h index 15a19f48ae7..73a52854765 100644 --- a/src/client/ns_client.h +++ b/src/client/ns_client.h @@ -110,6 +110,10 @@ class NsClient : public Client { bool DropTable(const std::string& name, std::string& msg); // NOLINT + bool PutUser(const std::string& host, const std::string& name, const std::string& password); // NOLINT + + bool DeleteUser(const std::string& host, const 
std::string& name); // NOLINT + bool DropTable(const std::string& db, const std::string& name, std::string& msg); // NOLINT diff --git a/src/client/tablet_client.cc b/src/client/tablet_client.cc index 09e7bdcbd96..a1dd925fcde 100644 --- a/src/client/tablet_client.cc +++ b/src/client/tablet_client.cc @@ -26,6 +26,7 @@ #include "codec/sql_rpc_row_codec.h" #include "common/timer.h" #include "sdk/sql_request_row.h" +#include "tablet_client.h" DECLARE_int32(request_max_retry); DECLARE_int32(request_timeout_ms); @@ -1414,5 +1415,16 @@ bool TabletClient::GetAndFlushDeployStats(::openmldb::api::DeployStatsResponse* return ok && res->code() == 0; } +bool TabletClient::FlushPrivileges() { + ::openmldb::api::EmptyRequest request; + ::openmldb::api::GeneralResponse response; + + bool ok = client_.SendRequest(&::openmldb::api::TabletServer_Stub::FlushPrivileges, &request, &response, + FLAGS_request_timeout_ms, 1); + if (ok && response.code() == 0) { + return true; + } + return false; +} } // namespace client } // namespace openmldb diff --git a/src/client/tablet_client.h b/src/client/tablet_client.h index 66155c968d7..177124208fc 100644 --- a/src/client/tablet_client.h +++ b/src/client/tablet_client.h @@ -267,6 +267,8 @@ class TabletClient : public Client { bool GetAndFlushDeployStats(::openmldb::api::DeployStatsResponse* res); + bool FlushPrivileges(); + private: base::Status LoadTableInternal(const ::openmldb::api::TableMeta& table_meta, std::shared_ptr task_info); diff --git a/src/cmd/openmldb.cc b/src/cmd/openmldb.cc index b13694d8d3c..3b3aa38cb5d 100644 --- a/src/cmd/openmldb.cc +++ b/src/cmd/openmldb.cc @@ -38,7 +38,6 @@ #endif #include "apiserver/api_server_impl.h" #include "auth/brpc_authenticator.h" -#include "auth/user_access_manager.h" #include "boost/algorithm/string.hpp" #include "boost/lexical_cast.hpp" #include "brpc/server.h" @@ -147,12 +146,10 @@ void StartNameServer() { } brpc::ServerOptions options; - std::unique_ptr user_access_manager; std::unique_ptr 
server_authenticator; - user_access_manager = std::make_unique(name_server->GetSystemTableIterator()); server_authenticator = std::make_unique( - [&user_access_manager](const std::string& host, const std::string& username, const std::string& password) { - return user_access_manager->IsAuthenticated(host, username, password); + [name_server](const std::string& host, const std::string& username, const std::string& password) { + return name_server->IsAuthenticated(host, username, password); }); options.auth = server_authenticator.get(); @@ -253,13 +250,11 @@ void StartTablet() { exit(1); } brpc::ServerOptions options; - std::unique_ptr user_access_manager; std::unique_ptr server_authenticator; - user_access_manager = std::make_unique(tablet->GetSystemTableIterator()); server_authenticator = std::make_unique( - [&user_access_manager](const std::string& host, const std::string& username, const std::string& password) { - return user_access_manager->IsAuthenticated(host, username, password); + [tablet](const std::string& host, const std::string& username, const std::string& password) { + return tablet->IsAuthenticated(host, username, password); }); options.auth = server_authenticator.get(); options.num_threads = FLAGS_thread_pool_size; diff --git a/src/cmd/sql_cmd_test.cc b/src/cmd/sql_cmd_test.cc index cedda42a6cd..fe8faa21504 100644 --- a/src/cmd/sql_cmd_test.cc +++ b/src/cmd/sql_cmd_test.cc @@ -245,7 +245,6 @@ TEST_P(DBSDKTest, TestUser) { ASSERT_TRUE(status.IsOK()); ASSERT_TRUE(true); auto opt = sr->GetRouterOptions(); - std::this_thread::sleep_for(std::chrono::seconds(1)); // TODO(oh2024): Remove when CREATE USER becomes strongly if (cs->IsClusterMode()) { auto real_opt = std::dynamic_pointer_cast(opt); sdk::SQLRouterOptions opt1; @@ -257,7 +256,6 @@ TEST_P(DBSDKTest, TestUser) { ASSERT_TRUE(router != nullptr); sr->ExecuteSQL(absl::StrCat("ALTER USER user1 SET OPTIONS(password='abc')"), &status); ASSERT_TRUE(status.IsOK()); - 
std::this_thread::sleep_for(std::chrono::seconds(1)); // TODO(oh2024): Remove when CREATE USER becomes strongly router = NewClusterSQLRouter(opt1); ASSERT_FALSE(router != nullptr); } else { @@ -271,7 +269,6 @@ TEST_P(DBSDKTest, TestUser) { ASSERT_TRUE(router != nullptr); sr->ExecuteSQL(absl::StrCat("ALTER USER user1 SET OPTIONS(password='abc')"), &status); ASSERT_TRUE(status.IsOK()); - std::this_thread::sleep_for(std::chrono::seconds(1)); // TODO(oh2024): Remove when CREATE USER becomes strongly router = NewStandaloneSQLRouter(opt1); ASSERT_FALSE(router != nullptr); } diff --git a/src/nameserver/name_server_impl.cc b/src/nameserver/name_server_impl.cc index 871adcb8d49..9c565272fb3 100644 --- a/src/nameserver/name_server_impl.cc +++ b/src/nameserver/name_server_impl.cc @@ -45,6 +45,7 @@ #include "boost/bind.hpp" #include "codec/row_codec.h" #include "gflags/gflags.h" +#include "name_server_impl.h" #include "schema/index_util.h" #include "schema/schema_adapter.h" @@ -522,7 +523,8 @@ NameServerImpl::NameServerImpl() thread_pool_(1), task_thread_pool_(FLAGS_name_server_task_pool_size), rand_(0xdeadbeef), - startup_mode_(::openmldb::type::StartupMode::kStandalone) {} + startup_mode_(::openmldb::type::StartupMode::kStandalone), + user_access_manager_(GetSystemTableIterator()) {} NameServerImpl::~NameServerImpl() { running_.store(false, std::memory_order_release); @@ -650,7 +652,7 @@ bool NameServerImpl::Recover() { if (!RecoverExternalFunction()) { return false; } - return true; + return FlushPrivileges().OK(); } bool NameServerImpl::RecoverExternalFunction() { @@ -1377,8 +1379,8 @@ void NameServerImpl::ShowTablet(RpcController* controller, const ShowTabletReque response->set_msg("ok"); } -base::Status NameServerImpl::InsertUserRecord(const std::string& host, const std::string& user, - const std::string& password) { +base::Status NameServerImpl::PutUserRecord(const std::string& host, const std::string& user, + const std::string& password) { std::shared_ptr table_info; 
if (!GetTableInfo(USER_INFO_NAME, INTERNAL_DB, &table_info)) { return {ReturnCode::kTableIsNotExist, "user table does not exist"}; @@ -1388,13 +1390,13 @@ base::Status NameServerImpl::InsertUserRecord(const std::string& host, const std row_values.push_back(host); row_values.push_back(user); row_values.push_back(password); - row_values.push_back(""); // password_last_changed - row_values.push_back(""); // password_expired_time - row_values.push_back(""); // create_time - row_values.push_back(""); // update_time - row_values.push_back(""); // account_type - row_values.push_back(""); // privileges - row_values.push_back(""); // extra_info + row_values.push_back("0"); // password_last_changed + row_values.push_back("0"); // password_expired_time + row_values.push_back("0"); // create_time + row_values.push_back("0"); // update_time + row_values.push_back("1"); // account_type + row_values.push_back("0"); // privileges + row_values.push_back("null"); // extra_info std::string encoded_row; codec::RowCodec::EncodeRow(row_values, table_info->column_desc(), 1, encoded_row); @@ -1410,11 +1412,56 @@ base::Status NameServerImpl::InsertUserRecord(const std::string& host, const std std::string endpoint = table_partition.partition_meta(meta_idx).endpoint(); auto table_ptr = GetTablet(endpoint); if (!table_ptr->client_->Put(tid, 0, cur_ts, encoded_row, dimensions).OK()) { - return {ReturnCode::kPutFailed, "failed to create initial user entry"}; + return {ReturnCode::kPutFailed, "failed to put user entry"}; } break; } } + return FlushPrivileges(); +} + +base::Status NameServerImpl::DeleteUserRecord(const std::string& host, const std::string& user) { + std::shared_ptr table_info; + if (!GetTableInfo(USER_INFO_NAME, INTERNAL_DB, &table_info)) { + return {ReturnCode::kTableIsNotExist, "user table does not exist"}; + } + uint32_t tid = table_info->tid(); + auto table_partition = table_info->table_partition(0); // only one partition for system table + std::string msg; + for (int 
meta_idx = 0; meta_idx < table_partition.partition_meta_size(); meta_idx++) { + if (table_partition.partition_meta(meta_idx).is_leader() && + table_partition.partition_meta(meta_idx).is_alive()) { + uint64_t cur_ts = ::baidu::common::timer::get_micros() / 1000; + std::string endpoint = table_partition.partition_meta(meta_idx).endpoint(); + auto table_ptr = GetTablet(endpoint); + if (!table_ptr->client_->Delete(tid, 0, host + "|" + user, "index", msg)) { + return {ReturnCode::kDeleteFailed, msg}; + } + + break; + } + } + return FlushPrivileges(); +} + +base::Status NameServerImpl::FlushPrivileges() { + user_access_manager_.SyncWithDB(); + std::vector failed_tablet_list; + { + std::lock_guard lock(mu_); + for (const auto& tablet_pair : tablets_) { + const std::shared_ptr& tablet_info = tablet_pair.second; + if (tablet_info && tablet_info->Health() && tablet_info->client_) { + if (!tablet_info->client_->FlushPrivileges()) { + failed_tablet_list.push_back(tablet_pair.first); + } + } + } + } + if (failed_tablet_list.size() > 0) { + return {ReturnCode::kFlushPrivilegesFailed, + "Failed to flush privileges to tablets: " + boost::algorithm::join(failed_tablet_list, ", ")}; + } return {}; } @@ -5593,7 +5640,7 @@ void NameServerImpl::OnLocked() { CreateDatabaseOrExit(INTERNAL_DB); if (db_table_info_[INTERNAL_DB].count(USER_INFO_NAME) == 0) { CreateSystemTableOrExit(SystemTableType::kUser); - InsertUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"); + PutUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"); } if (IsClusterMode()) { if (tablets_.size() < FLAGS_system_table_replica_num) { @@ -9613,6 +9660,25 @@ NameServerImpl::GetSystemTableIterator() { }; } +void NameServerImpl::PutUser(RpcController* controller, const PutUserRequest* request, GeneralResponse* response, + Closure* done) { + brpc::ClosureGuard done_guard(done); + auto status = PutUserRecord(request->host(), request->name(), 
request->password()); + base::SetResponseStatus(status, response); +} + +void NameServerImpl::DeleteUser(RpcController* controller, const DeleteUserRequest* request, GeneralResponse* response, + Closure* done) { + brpc::ClosureGuard done_guard(done); + auto status = DeleteUserRecord(request->host(), request->name()); + base::SetResponseStatus(status, response); +} + +bool NameServerImpl::IsAuthenticated(const std::string& host, const std::string& username, + const std::string& password) { + return user_access_manager_.IsAuthenticated(host, username, password); +} + bool NameServerImpl::RecoverProcedureInfo() { db_table_sp_map_.clear(); db_sp_table_map_.clear(); diff --git a/src/nameserver/name_server_impl.h b/src/nameserver/name_server_impl.h index 9960fd3d247..dadc335c7a3 100644 --- a/src/nameserver/name_server_impl.h +++ b/src/nameserver/name_server_impl.h @@ -29,6 +29,7 @@ #include #include +#include "auth/user_access_manager.h" #include "base/hash.h" #include "base/random.h" #include "catalog/distribute_iterator.h" @@ -358,15 +359,23 @@ class NameServerImpl : public NameServer { void DropProcedure(RpcController* controller, const api::DropProcedureRequest* request, GeneralResponse* response, Closure* done); + void PutUser(RpcController* controller, const PutUserRequest* request, GeneralResponse* response, Closure* done); + void DeleteUser(RpcController* controller, const DeleteUserRequest* request, GeneralResponse* response, + Closure* done); + bool IsAuthenticated(const std::string& host, const std::string& username, const std::string& password); + + private: + std::function, std::unique_ptr>>(const std::string& table_name)> GetSystemTableIterator(); - + bool GetTableInfo(const std::string& table_name, const std::string& db_name, std::shared_ptr* table_info); - private: - base::Status InsertUserRecord(const std::string& host, const std::string& user, const std::string& password); + base::Status PutUserRecord(const std::string& host, const std::string& user, 
const std::string& password); + base::Status DeleteUserRecord(const std::string& host, const std::string& user); + base::Status FlushPrivileges(); base::Status InitGlobalVarTable(); @@ -735,6 +744,7 @@ class NameServerImpl : public NameServer { std::unordered_map>> db_sp_info_map_; ::openmldb::type::StartupMode startup_mode_; + openmldb::auth::UserAccessManager user_access_manager_; }; } // namespace nameserver diff --git a/src/proto/name_server.proto b/src/proto/name_server.proto index c75dca8f5a9..f7c8fd5c830 100755 --- a/src/proto/name_server.proto +++ b/src/proto/name_server.proto @@ -533,6 +533,17 @@ message TableIndex { repeated openmldb.common.ColumnKey column_key = 3; } +message PutUserRequest { + required string host = 1; + required string name = 2; + required string password = 3; +} + +message DeleteUserRequest { + required string host = 1; + required string name = 2; +} + message DeploySQLRequest { optional openmldb.api.ProcedureInfo sp_info = 3; repeated TableIndex index = 4; @@ -602,4 +613,8 @@ service NameServer { rpc DropProcedure(openmldb.api.DropProcedureRequest) returns (GeneralResponse); rpc ShowProcedure(openmldb.api.ShowProcedureRequest) returns (openmldb.api.ShowProcedureResponse); rpc DeploySQL(DeploySQLRequest) returns (DeploySQLResponse); + + // user related interfaces + rpc PutUser(PutUserRequest) returns (GeneralResponse); + rpc DeleteUser(DeleteUserRequest) returns (GeneralResponse); } diff --git a/src/proto/tablet.proto b/src/proto/tablet.proto index a1ae6e72d5a..bc160a01f1e 100755 --- a/src/proto/tablet.proto +++ b/src/proto/tablet.proto @@ -977,4 +977,7 @@ service TabletServer { rpc CreateAggregator(CreateAggregatorRequest) returns (CreateAggregatorResponse); // monitoring interfaces rpc GetAndFlushDeployStats(GAFDeployStatsRequest) returns (DeployStatsResponse); + + // flush privilege + rpc FlushPrivileges(EmptyRequest) returns (GeneralResponse); } diff --git a/src/sdk/mini_cluster.h b/src/sdk/mini_cluster.h index 
673e1cb1f61..24521005772 100644 --- a/src/sdk/mini_cluster.h +++ b/src/sdk/mini_cluster.h @@ -26,7 +26,6 @@ #include #include "auth/brpc_authenticator.h" -#include "auth/user_access_manager.h" #include "base/file_util.h" #include "base/glog_wrapper.h" #include "brpc/server.h" @@ -74,12 +73,6 @@ class MiniCluster { } ~MiniCluster() { - for (auto& tablet_user_access_manager : tablet_user_access_managers_) { - if (tablet_user_access_manager) { - delete tablet_user_access_manager; - tablet_user_access_manager = nullptr; - } - } for (auto& tablet_authenticator : tablet_authenticators_) { if (tablet_authenticator) { delete tablet_authenticator; @@ -87,11 +80,6 @@ class MiniCluster { } } - if (user_access_manager_) { - delete user_access_manager_; - user_access_manager_ = nullptr; - } - if (ns_authenticator_) { delete ns_authenticator_; ns_authenticator_ = nullptr; @@ -137,15 +125,9 @@ class MiniCluster { if (!ok) { return false; } - if (!nameserver->GetTableInfo(::openmldb::nameserver::USER_INFO_NAME, ::openmldb::nameserver::INTERNAL_DB, - &user_table_info_)) { - PDLOG(WARNING, "Failed to get table info for user table"); - return false; - } - user_access_manager_ = new openmldb::auth::UserAccessManager(nameserver->GetSystemTableIterator()); ns_authenticator_ = new openmldb::authn::BRPCAuthenticator( [this](const std::string& host, const std::string& username, const std::string& password) { - return user_access_manager_->IsAuthenticated(host, username, password); + return nameserver->IsAuthenticated(host, username, password); }); brpc::ServerOptions options; options.auth = ns_authenticator_; @@ -173,23 +155,12 @@ class MiniCluster { } void Close() { - for (auto& tablet_user_access_manager : tablet_user_access_managers_) { - if (tablet_user_access_manager) { - delete tablet_user_access_manager; - tablet_user_access_manager = nullptr; - } - } for (auto& tablet_authenticator : tablet_authenticators_) { if (tablet_authenticator) { delete tablet_authenticator; 
tablet_authenticator = nullptr; } } - if (user_access_manager_) { - delete user_access_manager_; - user_access_manager_ = nullptr; - } - if (ns_authenticator_) { delete ns_authenticator_; ns_authenticator_ = nullptr; @@ -244,13 +215,10 @@ class MiniCluster { return false; } - auto tablet_user_access_manager = new openmldb::auth::UserAccessManager(tablet->GetSystemTableIterator()); auto ts_authenticator = new openmldb::authn::BRPCAuthenticator( - [tablet_user_access_manager](const std::string& host, const std::string& username, - const std::string& password) { - return tablet_user_access_manager->IsAuthenticated(host, username, password); + [tablet](const std::string& host, const std::string& username, const std::string& password) { + return tablet->IsAuthenticated(host, username, password); }); - tablet_user_access_managers_.push_back(tablet_user_access_manager); tablet_authenticators_.push_back(ts_authenticator); brpc::ServerOptions options; options.auth = ts_authenticator; @@ -291,22 +259,13 @@ class MiniCluster { std::map tablets_; std::map tb_clients_; openmldb::authn::BRPCAuthenticator* ns_authenticator_; - openmldb::auth::UserAccessManager* user_access_manager_; - std::vector tablet_user_access_managers_; std::vector tablet_authenticators_; - std::shared_ptr<::openmldb::nameserver::TableInfo> user_table_info_; }; class StandaloneEnv { public: StandaloneEnv() : ns_(), ns_client_(nullptr), tb_client_(nullptr) { FLAGS_skip_grant_tables = false; } ~StandaloneEnv() { - for (auto& tablet_user_access_manager : tablet_user_access_managers_) { - if (tablet_user_access_manager) { - delete tablet_user_access_manager; - tablet_user_access_manager = nullptr; - } - } for (auto& tablet_authenticator : tablet_authenticators_) { if (tablet_authenticator) { delete tablet_authenticator; @@ -314,11 +273,6 @@ class StandaloneEnv { } } - if (user_access_manager_) { - delete user_access_manager_; - user_access_manager_ = nullptr; - } - if (ns_authenticator_) { delete 
ns_authenticator_; ns_authenticator_ = nullptr; @@ -353,15 +307,9 @@ class StandaloneEnv { if (!ok) { return false; } - if (!nameserver->GetTableInfo(::openmldb::nameserver::USER_INFO_NAME, ::openmldb::nameserver::INTERNAL_DB, - &user_table_info_)) { - PDLOG(WARNING, "Failed to get table info for user table"); - return false; - } - user_access_manager_ = new openmldb::auth::UserAccessManager(nameserver->GetSystemTableIterator()); ns_authenticator_ = new openmldb::authn::BRPCAuthenticator( [this](const std::string& host, const std::string& username, const std::string& password) { - return user_access_manager_->IsAuthenticated(host, username, password); + return nameserver->IsAuthenticated(host, username, password); }); brpc::ServerOptions options; options.auth = ns_authenticator_; @@ -387,12 +335,6 @@ class StandaloneEnv { } void Close() { - for (auto& tablet_user_access_manager : tablet_user_access_managers_) { - if (tablet_user_access_manager) { - delete tablet_user_access_manager; - tablet_user_access_manager = nullptr; - } - } for (auto& tablet_authenticator : tablet_authenticators_) { if (tablet_authenticator) { delete tablet_authenticator; @@ -400,11 +342,6 @@ class StandaloneEnv { } } - if (user_access_manager_) { - delete user_access_manager_; - user_access_manager_ = nullptr; - } - if (ns_authenticator_) { delete ns_authenticator_; ns_authenticator_ = nullptr; @@ -436,15 +373,12 @@ class StandaloneEnv { bool ok = tablet->Init("", "", tb_endpoint, ""); if (!ok) { return false; - } + } - auto tablet_user_access_manager = new openmldb::auth::UserAccessManager(tablet->GetSystemTableIterator()); auto ts_authenticator = new openmldb::authn::BRPCAuthenticator( - [tablet_user_access_manager](const std::string& host, const std::string& username, - const std::string& password) { - return tablet_user_access_manager->IsAuthenticated(host, username, password); + [tablet](const std::string& host, const std::string& username, const std::string& password) { + return 
tablet->IsAuthenticated(host, username, password); }); - tablet_user_access_managers_.push_back(tablet_user_access_manager); tablet_authenticators_.push_back(ts_authenticator); brpc::ServerOptions options; options.auth = ts_authenticator; @@ -474,9 +408,6 @@ class StandaloneEnv { ::openmldb::client::NsClient* ns_client_; ::openmldb::client::TabletClient* tb_client_; openmldb::authn::BRPCAuthenticator* ns_authenticator_; - openmldb::auth::UserAccessManager* user_access_manager_; - std::shared_ptr<::openmldb::nameserver::TableInfo> user_table_info_; - std::vector tablet_user_access_managers_; std::vector tablet_authenticators_; }; diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index 705fbd62400..e58eb8cd2cc 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -4902,49 +4902,51 @@ absl::StatusOr SQLClusterRouter::GetUser(const std::string& name, UserInfo hybridse::sdk::Status SQLClusterRouter::AddUser(const std::string& name, const std::string& password) { auto real_password = password.empty() ? 
password : codec::Encrypt(password); - uint64_t cur_ts = ::baidu::common::timer::get_micros() / 1000; - std::string sql = absl::StrCat("insert into ", nameserver::USER_INFO_NAME, " values (", - "'%',", // host - "'", name, "','", // user - real_password, "',", // password - cur_ts, ",", // password_last_changed - "0,", // password_expired_time - cur_ts, ", ", // create_time - cur_ts, ",", // update_time - 1, // account_type - ",'',", // privileges - "null" // extra_info - ");"); + hybridse::sdk::Status status; - ExecuteInsert(nameserver::INTERNAL_DB, sql, &status); + + auto ns_client = cluster_sdk_->GetNsClient(); + + bool ok = ns_client->PutUser("%", name, real_password); + + if (!ok) { + status.code = hybridse::common::StatusCode::kRunError; + status.msg = absl::StrCat("Fail to create user: ", name); + } + return status; } hybridse::sdk::Status SQLClusterRouter::UpdateUser(const UserInfo& user_info, const std::string& password) { + auto name = user_info.name; auto real_password = password.empty() ? 
password : codec::Encrypt(password); - uint64_t cur_ts = ::baidu::common::timer::get_micros() / 1000; - std::string sql = absl::StrCat("insert into ", nameserver::USER_INFO_NAME, " values (", - "'%',", // host - "'", user_info.name, "','", // user - real_password, "',", // password - cur_ts, ",", // password_last_changed - "0,", // password_expired_time - user_info.create_time, ", ", // create_time - cur_ts, ",", // update_time - 1, // account_type - ",'", user_info.privileges, "',", // privileges - "null" // extra_info - ");"); + hybridse::sdk::Status status; - ExecuteInsert(nameserver::INTERNAL_DB, sql, &status); + + auto ns_client = cluster_sdk_->GetNsClient(); + + bool ok = ns_client->PutUser("%", name, real_password); + + if (!ok) { + status.code = hybridse::common::StatusCode::kRunError; + status.msg = absl::StrCat("Fail to update user: ", name); + } + return status; } hybridse::sdk::Status SQLClusterRouter::DeleteUser(const std::string& name) { - std::string sql = absl::StrCat("delete from ", nameserver::USER_INFO_NAME, - " where host = '%' and user = '", name, "';"); hybridse::sdk::Status status; - ExecuteSQL(nameserver::INTERNAL_DB, sql, &status); + + auto ns_client = cluster_sdk_->GetNsClient(); + + bool ok = ns_client->DeleteUser("%", name); + + if (!ok) { + status.code = hybridse::common::StatusCode::kRunError; + status.msg = absl::StrCat("Fail to delete user: ", name); + } + return status; } diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index 8c59a4f9184..2f7544f2847 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -154,7 +154,8 @@ TabletImpl::TabletImpl() sp_cache_(std::shared_ptr(new SpCache())), notify_path_(), globalvar_changed_notify_path_(), - startup_mode_(::openmldb::type::StartupMode::kStandalone) {} + startup_mode_(::openmldb::type::StartupMode::kStandalone), + user_access_manager_(GetSystemTableIterator()) {} TabletImpl::~TabletImpl() { task_pool_.Stop(true); @@ -5814,6 +5815,17 @@ void 
TabletImpl::GetAndFlushDeployStats(::google::protobuf::RpcController* contr response->set_code(ReturnCode::kOk); } +bool TabletImpl::IsAuthenticated(const std::string& host, const std::string& username, const std::string& password) { + return user_access_manager_.IsAuthenticated(host, username, password); +} + +void TabletImpl::FlushPrivileges(::google::protobuf::RpcController* controller, + const ::openmldb::api::EmptyRequest* request, + ::openmldb::api::GeneralResponse* response, ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + user_access_manager_.SyncWithDB(); +} + std::function, std::unique_ptr>>(const std::string& table_name)> TabletImpl::GetSystemTableIterator() { diff --git a/src/tablet/tablet_impl.h b/src/tablet/tablet_impl.h index 89ab1c1befa..d299956cb9e 100644 --- a/src/tablet/tablet_impl.h +++ b/src/tablet/tablet_impl.h @@ -26,6 +26,7 @@ #include #include +#include "auth/user_access_manager.h" #include "base/spinlock.h" #include "brpc/server.h" #include "catalog/tablet_catalog.h" @@ -274,11 +275,15 @@ class TabletImpl : public ::openmldb::api::TabletServer { ::openmldb::api::DeployStatsResponse* response, ::google::protobuf::Closure* done) override; + bool IsAuthenticated(const std::string& host, const std::string& username, const std::string& password); + void FlushPrivileges(::google::protobuf::RpcController* controller, const ::openmldb::api::EmptyRequest* request, + ::openmldb::api::GeneralResponse* response, ::google::protobuf::Closure* done); + + private: std::function, std::unique_ptr>>(const std::string& table_name)> GetSystemTableIterator(); - private: class UpdateAggrClosure : public Closure { public: explicit UpdateAggrClosure(const std::function& callback) : callback_(callback) {} @@ -489,6 +494,7 @@ class TabletImpl : public ::openmldb::api::TabletServer { std::unique_ptr deploy_collector_; std::atomic memory_used_ = 0; std::atomic system_memory_usage_rate_ = 0; // [0, 100] + openmldb::auth::UserAccessManager 
user_access_manager_; }; } // namespace tablet From 59d79f6d116fd3fbb9496d47ca9432375d6a06f0 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Mon, 27 May 2024 14:08:42 +0800 Subject: [PATCH 13/23] feat(parser): simple ANSI SQL rewriter (#3934) * feat(parser): simple ANSI SQL rewriter * feat(draft): translate request mode query * feat: request query rewriter * test: tpc rewrite cases * feat(rewrite): enable ansi sql rewriter in `ExecuteSQL` You may explicitly set this feature on via `set session ansi_sql_rewriter = 'true'` TODO: this rewriter feature should be off by default --- hybridse/src/CMakeLists.txt | 1 + hybridse/src/rewriter/ast_rewriter.cc | 573 +++++++++++++++++++++ hybridse/src/rewriter/ast_rewriter.h | 32 ++ hybridse/src/rewriter/ast_rewriter_test.cc | 237 +++++++++ src/sdk/sql_cluster_router.cc | 39 +- src/sdk/sql_cluster_router.h | 2 + 6 files changed, 883 insertions(+), 1 deletion(-) create mode 100644 hybridse/src/rewriter/ast_rewriter.cc create mode 100644 hybridse/src/rewriter/ast_rewriter.h create mode 100644 hybridse/src/rewriter/ast_rewriter_test.cc diff --git a/hybridse/src/CMakeLists.txt b/hybridse/src/CMakeLists.txt index 80c5cc2a5a3..4f25d87ab70 100644 --- a/hybridse/src/CMakeLists.txt +++ b/hybridse/src/CMakeLists.txt @@ -48,6 +48,7 @@ hybridse_add_src_and_tests(vm) hybridse_add_src_and_tests(codec) hybridse_add_src_and_tests(case) hybridse_add_src_and_tests(passes) +hybridse_add_src_and_tests(rewriter) get_property(SRC_FILE_LIST_STR GLOBAL PROPERTY PROP_SRC_FILE_LIST) string(REPLACE " " ";" SRC_FILE_LIST ${SRC_FILE_LIST_STR}) diff --git a/hybridse/src/rewriter/ast_rewriter.cc b/hybridse/src/rewriter/ast_rewriter.cc new file mode 100644 index 00000000000..9dc90ffdee3 --- /dev/null +++ b/hybridse/src/rewriter/ast_rewriter.cc @@ -0,0 +1,573 @@ +/** + * Copyright (c) 2024 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "rewriter/ast_rewriter.h" + +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "plan/plan_api.h" +#include "zetasql/parser/parse_tree_manual.h" +#include "zetasql/parser/parser.h" +#include "zetasql/parser/unparser.h" + +namespace hybridse { +namespace rewriter { + +// unparser that make some rewrites so outputed SQL is +// compatible with ANSI SQL as much as can +class LastJoinRewriteUnparser : public zetasql::parser::Unparser { + public: + explicit LastJoinRewriteUnparser(std::string* unparsed) : zetasql::parser::Unparser(unparsed) {} + ~LastJoinRewriteUnparser() override {} + LastJoinRewriteUnparser(const LastJoinRewriteUnparser&) = delete; + LastJoinRewriteUnparser& operator=(const LastJoinRewriteUnparser&) = delete; + + void visitASTSelect(const zetasql::ASTSelect* node, void* data) override { + while (true) { + absl::string_view filter_col; + + // 1. filter condition is 'col = 1' + if (node->where_clause() != nullptr && + node->where_clause()->expression()->node_kind() == zetasql::AST_BINARY_EXPRESSION) { + auto expr = node->where_clause()->expression()->GetAsOrNull(); + if (expr && expr->op() == zetasql::ASTBinaryExpression::Op::EQ && !expr->is_not()) { + { + auto lval = expr->lhs()->GetAsOrNull(); + auto rval = expr->rhs()->GetAsOrNull(); + if (lval && rval && lval->image() == "1") { + // TODO(someone): + // 1. consider lval->iamge() as '1L' + // 2. consider rval as . 
+ filter_col = rval->last_name()->GetAsStringView(); + } + } + if (filter_col.empty()) { + auto lval = expr->rhs()->GetAsOrNull(); + auto rval = expr->lhs()->GetAsOrNull(); + if (lval && rval && lval->image() == "1") { + // TODO(someone): + // 1. consider lval->iamge() as '1L' + // 2. consider rval as . + filter_col = rval->last_name()->GetAsStringView(); + } + } + } + } + + // 2. FROM a subquery: SELECT ... t1 LEFT JOIN t2 WINDOW + const zetasql::ASTPathExpression* join_lhs_key = nullptr; + const zetasql::ASTPathExpression* join_rhs_key = nullptr; + if (node->from_clause() == nullptr) { + break; + } + auto sub = node->from_clause()->table_expression()->GetAsOrNull(); + if (!sub) { + break; + } + auto subquery = sub->subquery(); + if (subquery->with_clause() != nullptr || subquery->order_by() != nullptr || + subquery->limit_offset() != nullptr) { + break; + } + + auto inner_select = subquery->query_expr()->GetAsOrNull(); + if (!inner_select) { + break; + } + // select have window + if (inner_select->window_clause() == nullptr || inner_select->from_clause() == nullptr) { + break; + } + + // 3. CHECK FROM CLAUSE: must 't1 LEFT JOIN t2 on t1.key = t2.key' + if (!inner_select->from_clause()) { + break; + } + auto join = inner_select->from_clause()->table_expression()->GetAsOrNull(); + if (join == nullptr || join->join_type() != zetasql::ASTJoin::LEFT || join->on_clause() == nullptr) { + break; + } + auto on_expr = join->on_clause()->expression()->GetAsOrNull(); + if (on_expr == nullptr || on_expr->is_not() || on_expr->op() != zetasql::ASTBinaryExpression::EQ) { + break; + } + + // still might null + join_lhs_key = on_expr->lhs()->GetAsOrNull(); + join_rhs_key = on_expr->rhs()->GetAsOrNull(); + if (join_lhs_key == nullptr || join_rhs_key == nullptr) { + break; + } + + // 3. 
CHECK row_id is row_number() over w FROM select_list + bool found = false; + absl::string_view window_name; + for (auto col : inner_select->select_list()->columns()) { + if (col->alias() && col->alias()->GetAsStringView() == filter_col) { + auto agg_func = col->expression()->GetAsOrNull(); + if (!agg_func || !agg_func->function()) { + break; + } + + auto w = agg_func->window_spec(); + if (!w || w->base_window_name() == nullptr) { + break; + } + window_name = w->base_window_name()->GetAsStringView(); + + auto ph = agg_func->function()->function(); + if (ph->num_names() == 1 && + absl::AsciiStrToLower(ph->first_name()->GetAsStringView()) == "row_number") { + opt_out_row_number_col_ = col; + found = true; + break; + } + } + } + if (!found || window_name.empty()) { + break; + } + + // 4. CHECK WINDOW CLAUSE + { + if (inner_select->window_clause()->windows().size() != 1) { + // targeting single window only + break; + } + auto win = inner_select->window_clause()->windows().front(); + if (win->name()->GetAsStringView() != window_name) { + break; + } + auto spec = win->window_spec(); + if (spec->window_frame() != nullptr || spec->partition_by() == nullptr || spec->order_by() == nullptr) { + // TODO(someone): allow unbounded window frame + break; + } + + // PARTITION BY contains join_lhs_key + // ORDER BY is join_rhs_key + bool partition_meet = false; + for (auto expr : spec->partition_by()->partitioning_expressions()) { + auto e = expr->GetAsOrNull(); + if (e) { + if (e->last_name()->GetAsStringView() == join_lhs_key->last_name()->GetAsStringView()) { + partition_meet = true; + } + } + } + + if (!partition_meet) { + break; + } + + if (spec->order_by()->ordering_expressions().size() != 1) { + break; + } + + if (spec->order_by()->ordering_expressions().front()->ordering_spec() != + zetasql::ASTOrderingExpression::DESC) { + break; + } + + auto e = spec->order_by() + ->ordering_expressions() + .front() + ->expression() + ->GetAsOrNull(); + if (!e) { + break; + } + + // rewrite 
+ { + opt_out_window_ = inner_select->window_clause(); + opt_out_where_ = node->where_clause(); + opt_join_ = join; + opt_in_last_join_order_by_ = e; + absl::Cleanup clean = [&]() { + opt_out_window_ = nullptr; + opt_out_where_ = nullptr; + opt_out_row_number_col_ = nullptr; + opt_join_ = nullptr; + }; + + // inline zetasql::parser::Unparser::visitASTSelect(node, data); + { + PrintOpenParenIfNeeded(node); + println(); + print("SELECT"); + if (node->hint() != nullptr) { + node->hint()->Accept(this, data); + } + if (node->anonymization_options() != nullptr) { + print("WITH ANONYMIZATION OPTIONS"); + node->anonymization_options()->Accept(this, data); + } + if (node->distinct()) { + print("DISTINCT"); + } + + // Visit all children except hint() and anonymization_options, which we + // processed above. We can't just use visitASTChildren(node, data) because + // we need to insert the DISTINCT modifier after the hint and anonymization + // nodes and before everything else. + for (int i = 0; i < node->num_children(); ++i) { + const zetasql::ASTNode* child = node->child(i); + if (child != node->hint() && child != node->anonymization_options()) { + child->Accept(this, data); + } + } + + println(); + PrintCloseParenIfNeeded(node); + } + + return; + } + } + + break; + } + + zetasql::parser::Unparser::visitASTSelect(node, data); + } + + void visitASTJoin(const zetasql::ASTJoin* node, void* data) override { + if (opt_join_ && opt_join_ == node) { + node->child(0)->Accept(this, data); + + if (node->join_type() == zetasql::ASTJoin::COMMA) { + print(","); + } else { + println(); + if (node->natural()) { + print("NATURAL"); + } + print("LAST"); + print(node->GetSQLForJoinHint()); + + print("JOIN"); + } + println(); + + // This will print hints, the rhs, and the ON or USING clause. 
+ for (int i = 1; i < node->num_children(); i++) { + node->child(i)->Accept(this, data); + if (opt_in_last_join_order_by_ && node->child(i)->IsTableExpression()) { + print("ORDER BY"); + opt_in_last_join_order_by_->Accept(this, data); + } + } + + return; + } + + zetasql::parser::Unparser::visitASTJoin(node, data); + } + + void visitASTSelectList(const zetasql::ASTSelectList* node, void* data) override { + println(); + { + for (int i = 0; i < node->num_children(); i++) { + if (opt_out_row_number_col_ && node->columns(i) == opt_out_row_number_col_) { + continue; + } + if (i > 0) { + println(","); + } + node->child(i)->Accept(this, data); + } + } + } + + void visitASTWindowClause(const zetasql::ASTWindowClause* node, void* data) override { + if (opt_out_window_ && opt_out_window_ == node) { + return; + } + + zetasql::parser::Unparser::visitASTWindowClause(node, data); + } + + void visitASTWhereClause(const zetasql::ASTWhereClause* node, void* data) override { + if (opt_out_where_ && opt_out_where_ == node) { + return; + } + zetasql::parser::Unparser::visitASTWhereClause(node, data); + } + + private: + const zetasql::ASTWindowClause* opt_out_window_ = nullptr; + const zetasql::ASTWhereClause* opt_out_where_ = nullptr; + const zetasql::ASTSelectColumn* opt_out_row_number_col_ = nullptr; + const zetasql::ASTJoin* opt_join_ = nullptr; + const zetasql::ASTPathExpression* opt_in_last_join_order_by_ = nullptr; +}; + +// SELECT: +// WHERE col = 0 +// FROM (subquery): +// subquery is UNION ALL, or contains left-most query is UNION ALL +// and UNION ALL is select const ..., 0 as col UNION ALL (select .., 1 as col table) +class RequestQueryRewriteUnparser : public zetasql::parser::Unparser { + public: + explicit RequestQueryRewriteUnparser(std::string* unparsed) : zetasql::parser::Unparser(unparsed) {} + ~RequestQueryRewriteUnparser() override {} + RequestQueryRewriteUnparser(const RequestQueryRewriteUnparser&) = delete; + RequestQueryRewriteUnparser& operator=(const 
RequestQueryRewriteUnparser&) = delete; + + void visitASTSelect(const zetasql::ASTSelect* node, void* data) override { + while (true) { + if (outer_most_select_ != nullptr) { + break; + } + + outer_most_select_ = node; + if (node->where_clause() == nullptr) { + break; + } + absl::string_view filter_col; + const zetasql::ASTExpression* filter_expr; + + // 1. filter condition is 'col = 0' + if (node->where_clause()->expression()->node_kind() != zetasql::AST_BINARY_EXPRESSION) { + break; + } + auto expr = node->where_clause()->expression()->GetAsOrNull(); + if (!expr || expr->op() != zetasql::ASTBinaryExpression::Op::EQ || expr->is_not()) { + break; + } + { + auto rval = expr->rhs()->GetAsOrNull(); + if (rval) { + // TODO(someone): + // 2. consider rval as . + filter_col = rval->last_name()->GetAsStringView(); + filter_expr = expr->lhs(); + } + } + if (filter_col.empty()) { + auto rval = expr->lhs()->GetAsOrNull(); + if (rval) { + // TODO(someone): + // 2. consider rval as . + filter_col = rval->last_name()->GetAsStringView(); + filter_expr = expr->rhs(); + } + } + if (filter_col.empty() || !filter_expr) { + break; + } + + if (node->from_clause() == nullptr) { + break; + } + auto sub = node->from_clause()->table_expression()->GetAsOrNull(); + if (!sub) { + break; + } + auto subquery = sub->subquery(); + + findUnionAllForQuery(subquery, filter_col, filter_expr, node->where_clause()); + + break; // fallback normal + } + + zetasql::parser::Unparser::visitASTSelect(node, data); + } + + void visitASTSetOperation(const zetasql::ASTSetOperation* node, void* data) override { + if (node == detected_request_block_) { + node->inputs().back()->Accept(this, data); + } else { + zetasql::parser::Unparser::visitASTSetOperation(node, data); + } + } + + void visitASTQueryStatement(const zetasql::ASTQueryStatement* node, void* data) override { + node->query()->Accept(this, data); + if (!list_.empty() && !node->config_clause()) { + constSelectListAsConfigClause(list_, data); + } else { + 
if (node->config_clause() != nullptr) { + println(); + node->config_clause()->Accept(this, data); + } + } + } + + void visitASTWhereClause(const zetasql::ASTWhereClause* node, void* data) override { + if (node != filter_clause_) { + zetasql::parser::Unparser::visitASTWhereClause(node, data); + } + } + + private: + void findUnionAllForQuery(const zetasql::ASTQuery* query, absl::string_view label_name, + const zetasql::ASTExpression* filter_expr, const zetasql::ASTWhereClause* filter) { + if (!query) { + return; + } + auto qe = query->query_expr(); + switch (qe->node_kind()) { + case zetasql::AST_SET_OPERATION: { + auto set = qe->GetAsOrNull(); + if (set && set->op_type() == zetasql::ASTSetOperation::UNION && set->distinct() == false && + set->hint() == nullptr && set->inputs().size() == 2) { + [[maybe_unused]] bool ret = + findUnionAllInput(set->inputs().at(0), set->inputs().at(1), label_name, filter_expr, filter) || + findUnionAllInput(set->inputs().at(0), set->inputs().at(1), label_name, filter_expr, filter); + if (ret) { + detected_request_block_ = set; + } + } + break; + } + case zetasql::AST_QUERY: { + findUnionAllForQuery(qe->GetAsOrNull(), label_name, filter_expr, filter); + break; + } + case zetasql::AST_SELECT: { + auto select = qe->GetAsOrNull(); + if (select->from_clause() && + select->from_clause()->table_expression()->node_kind() == zetasql::AST_TABLE_SUBQUERY) { + auto sub = select->from_clause()->table_expression()->GetAsOrNull(); + if (sub && sub->subquery()) { + findUnionAllForQuery(sub->subquery(), label_name, filter_expr, filter); + } + } + break; + } + default: + break; + } + } + + void constSelectListAsConfigClause(const std::vector& selects, void* data) { + print("CONFIG (execute_mode = 'request', values = ("); + for (int i = 0; i < selects.size(); ++i) { + selects.at(i)->Accept(this, data); + if (i + 1 < selects.size()) { + print(","); + } + } + print(") )"); + } + + bool findUnionAllInput(const zetasql::ASTQueryExpression* lhs, const 
zetasql::ASTQueryExpression* rhs, + absl::string_view label_name, const zetasql::ASTExpression* filter_expr, + const zetasql::ASTWhereClause* filter) { + // lhs is select const + label_name of value 0 + auto lselect = lhs->GetAsOrNull(); + if (!lselect || lselect->num_children() > 1) { + // only select_list required, otherwise size > 1 + return false; + } + + bool has_label_col_0 = false; + const zetasql::ASTExpression* label_expr_0 = nullptr; + std::vector vec; + for (auto col : lselect->select_list()->columns()) { + if (col->alias() && col->alias()->GetAsStringView() == label_name) { + has_label_col_0 = true; + label_expr_0 = col->expression(); + } else { + vec.push_back(col->expression()); + } + } + + // rhs is simple selects from table + label_name of value 1 + auto rselect = rhs->GetAsOrNull(); + if (!rselect || rselect->num_children() > 2 || !rselect->from_clause()) { + // only select_list + from_clause required + return false; + } + if (rselect->from_clause()->table_expression()->node_kind() != zetasql::AST_TABLE_PATH_EXPRESSION) { + return false; + } + + bool has_label_col_1 = false; + const zetasql::ASTExpression* label_expr_1 = nullptr; + for (auto col : rselect->select_list()->columns()) { + if (col->alias() && col->alias()->GetAsStringView() == label_name) { + has_label_col_1 = true; + label_expr_1 = col->expression(); + } + } + + LOG(INFO) << "label expr 0: " << label_expr_0->SingleNodeDebugString(); + LOG(INFO) << "label expr 1: " << label_expr_1->SingleNodeDebugString(); + LOG(INFO) << "filter expr: " << filter_expr->SingleNodeDebugString(); + + if (has_label_col_0 && has_label_col_1 && + label_expr_0->SingleNodeDebugString() != label_expr_1->SingleNodeDebugString() && + label_expr_0->SingleNodeDebugString() == filter_expr->SingleNodeDebugString()) { + list_ = vec; + filter_clause_ = filter; + return true; + } + + return false; + } + + private: + const zetasql::ASTSelect* outer_most_select_ = nullptr; + // detected request query block, set by when 
visiting outer most query + const zetasql::ASTSetOperation* detected_request_block_ = nullptr; + const zetasql::ASTWhereClause* filter_clause_; + + std::vector list_; +}; + +absl::StatusOr Rewrite(absl::string_view query) { + auto str = std::string(query); + { + std::unique_ptr ast; + auto s = hybridse::plan::ParseStatement(str, &ast); + if (!s.ok()) { + return s; + } + + if (ast->statement() && ast->statement()->node_kind() == zetasql::AST_QUERY_STATEMENT) { + std::string unparsed_; + LastJoinRewriteUnparser unparser(&unparsed_); + ast->statement()->Accept(&unparser, nullptr); + unparser.FlushLine(); + str = unparsed_; + } + } + { + std::unique_ptr ast; + auto s = hybridse::plan::ParseStatement(str, &ast); + if (!s.ok()) { + return s; + } + + if (ast->statement() && ast->statement()->node_kind() == zetasql::AST_QUERY_STATEMENT) { + std::string unparsed_; + RequestQueryRewriteUnparser unparser(&unparsed_); + ast->statement()->Accept(&unparser, nullptr); + unparser.FlushLine(); + str = unparsed_; + } + } + + return str; +} + +} // namespace rewriter +} // namespace hybridse diff --git a/hybridse/src/rewriter/ast_rewriter.h b/hybridse/src/rewriter/ast_rewriter.h new file mode 100644 index 00000000000..17ea7ad0d04 --- /dev/null +++ b/hybridse/src/rewriter/ast_rewriter.h @@ -0,0 +1,32 @@ +/** + * Copyright (c) 2024 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef HYBRIDSE_SRC_REWRITER_AST_REWRITER_H_ +#define HYBRIDSE_SRC_REWRITER_AST_REWRITER_H_ + +#include + +#include "absl/status/statusor.h" + +namespace hybridse { +namespace rewriter { + +absl::StatusOr Rewrite(absl::string_view query); + +} // namespace rewriter +} // namespace hybridse + +#endif // HYBRIDSE_SRC_REWRITER_AST_REWRITER_H_ diff --git a/hybridse/src/rewriter/ast_rewriter_test.cc b/hybridse/src/rewriter/ast_rewriter_test.cc new file mode 100644 index 00000000000..7585ada71a6 --- /dev/null +++ b/hybridse/src/rewriter/ast_rewriter_test.cc @@ -0,0 +1,237 @@ +/** + * Copyright 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "rewriter/ast_rewriter.h" + +#include +#include + +#include "absl/strings/ascii.h" +#include "gtest/gtest.h" +#include "plan/plan_api.h" +#include "zetasql/parser/parser.h" + +namespace hybridse { +namespace rewriter { + +struct Case { + absl::string_view in; + absl::string_view out; +}; + +class ASTRewriterTest : public ::testing::TestWithParam {}; + +std::vector strip_cases = { + // eliminate LEFT JOIN WINDOW -> LAST JOIN + {R"s( + SELECT id, val, k, ts, idr, valr FROM ( + SELECT t1.*, t2.id as idr, t2.val as valr, row_number() over w as any_id + FROM t1 LEFT JOIN t2 ON t1.k = t2.k + WINDOW w as (PARTITION BY t1.id,t1.k order by t2.ts desc) + ) t WHERE any_id = 1)s", + R"e( +SELECT + id, + val, + k, + ts, + idr, + valr +FROM + ( + SELECT + t1.*, + t2.id AS idr, + t2.val AS valr + FROM + t1 + LAST JOIN + t2 + ORDER BY t2.ts + ON t1.k = t2.k + ) AS t +)e"}, + {R"( +SELECT id, k, agg +FROM ( + SELECT id, k, label, count(val) over w as agg + FROM ( + SELECT 6 as id, "xxx" as val, 10 as k, 9000 as ts, 0 as label + UNION ALL + SELECT *, 1 as label FROM t1 + ) t + WINDOW w as (PARTITION BY k ORDER BY ts rows between unbounded preceding and current row) +) t WHERE label = 0)", + R"( +SELECT + id, + k, + agg +FROM + ( + SELECT + id, + k, + label, + count(val) OVER (w) AS agg + FROM + ( + SELECT + *, + 1 AS label + FROM + t1 + ) AS t + WINDOW w AS (PARTITION BY k + ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + ) AS t +CONFIG (execute_mode = 'request', values = (6, "xxx", 10, 9000) ) +)"}, + // simplist request query + {R"s( + SELECT id, k + FROM ( + SELECT 6 as id, "xxx" as val, 10 as k, 9000 as ts, 0 as label + UNION ALL + SELECT *, 1 as label FROM t1 + ) t WHERE label = 0)s", + R"s(SELECT + id, + k +FROM + ( + SELECT + *, + 1 AS label + FROM + t1 + ) AS t +CONFIG (execute_mode = 'request', values = (6, "xxx", 10, 9000) ) +)s"}, + + // TPC-C case + {R"(SELECT C_ID, C_CITY, C_STATE, C_CREDIT, C_CREDIT_LIM, C_BALANCE, C_PAYMENT_CNT, 
C_DELIVERY_CNT + FROM ( + SELECT C_ID, C_CITY, C_STATE, C_CREDIT, C_CREDIT_LIM, C_BALANCE, C_PAYMENT_CNT, C_DELIVERY_CNT, label FROM ( + SELECT 1 AS C_ID, 1 AS C_D_ID, 1 AS C_W_ID, "John" AS C_FIRST, "M" AS C_MIDDLE, "Smith" AS C_LAST, "123 Main St" AS C_STREET_1, "Apt 101" AS C_STREET_2, "Springfield" AS C_CITY, "IL" AS C_STATE, 12345 AS C_ZIP, "555-123-4567" AS C_PHONE, timestamp("2024-01-01 00:00:00") AS C_SINCE, "BC" AS C_CREDIT, 10000.0 AS C_CREDIT_LIM, 0.5 AS C_DISCOUNT, 5000.0 AS C_BALANCE, 0.0 AS C_YTD_PAYMENT, 0 AS C_PAYMENT_CNT, 0 AS C_DELIVERY_CNT, "Additional customer data..." AS C_DATA, 0 as label + UNION ALL + SELECT *, 1 as label FROM CUSTOMER + ) t + ) t WHERE label = 0)", + R"s( +SELECT + C_ID, + C_CITY, + C_STATE, + C_CREDIT, + C_CREDIT_LIM, + C_BALANCE, + C_PAYMENT_CNT, + C_DELIVERY_CNT +FROM + ( + SELECT + C_ID, + C_CITY, + C_STATE, + C_CREDIT, + C_CREDIT_LIM, + C_BALANCE, + C_PAYMENT_CNT, + C_DELIVERY_CNT, + label + FROM + ( + SELECT + *, + 1 AS label + FROM + CUSTOMER + ) AS t + ) AS t +CONFIG (execute_mode = 'request', values = (1, 1, 1, "John", "M", "Smith", "123 Main St", "Apt 101", +"Springfield", "IL", 12345, "555-123-4567", timestamp("2024-01-01 00:00:00"), "BC", 10000.0, 0.5, 5000.0, +0.0, 0, 0, "Additional customer data...") ) + )s"}, + + {R"( +SELECT C_ID, C_CITY, C_STATE, C_CREDIT, C_CREDIT_LIM, C_BALANCE, C_PAYMENT_CNT, C_DELIVERY_CNT + FROM ( + SELECT C_ID, C_CITY, C_STATE, C_CREDIT, C_CREDIT_LIM, C_BALANCE, C_PAYMENT_CNT, C_DELIVERY_CNT, label FROM ( + SELECT 1 AS C_ID, 1 AS C_D_ID, 1 AS C_W_ID, "John" AS C_FIRST, "M" AS C_MIDDLE, "Smith" AS C_LAST, "123 Main St" AS C_STREET_1, "Apt 101" AS C_STREET_2, "Springfield" AS C_CITY, "IL" AS C_STATE, 12345 AS C_ZIP, "555-123-4567" AS C_PHONE, timestamp("2024-01-01 00:00:00") AS C_SINCE, "BC" AS C_CREDIT, 10000.0 AS C_CREDIT_LIM, 0.5 AS C_DISCOUNT, 9000.0 AS C_BALANCE, 0.0 AS C_YTD_PAYMENT, 0 AS C_PAYMENT_CNT, 0 AS C_DELIVERY_CNT, "Additional customer data..." 
AS C_DATA, 0 as label + UNION ALL + SELECT *, 1 as label FROM CUSTOMER + ) t + ) t WHERE label = 0)", + R"( +SELECT + C_ID, + C_CITY, + C_STATE, + C_CREDIT, + C_CREDIT_LIM, + C_BALANCE, + C_PAYMENT_CNT, + C_DELIVERY_CNT +FROM + ( + SELECT + C_ID, + C_CITY, + C_STATE, + C_CREDIT, + C_CREDIT_LIM, + C_BALANCE, + C_PAYMENT_CNT, + C_DELIVERY_CNT, + label + FROM + ( + SELECT + *, + 1 AS label + FROM + CUSTOMER + ) AS t + ) AS t +CONFIG (execute_mode = 'request', values = (1, 1, 1, "John", "M", "Smith", "123 Main St", "Apt 101", +"Springfield", "IL", 12345, "555-123-4567", timestamp("2024-01-01 00:00:00"), "BC", 10000.0, 0.5, 9000.0, +0.0, 0, 0, "Additional customer data...") ) +)"}, +}; + +INSTANTIATE_TEST_SUITE_P(Rules, ASTRewriterTest, ::testing::ValuesIn(strip_cases)); + +TEST_P(ASTRewriterTest, Correctness) { + auto& c = GetParam(); + + auto s = hybridse::rewriter::Rewrite(c.in); + ASSERT_TRUE(s.ok()) << s.status(); + + ASSERT_EQ(absl::StripAsciiWhitespace(c.out), absl::StripAsciiWhitespace(s.value())); + + std::unique_ptr out; + auto ss = ::hybridse::plan::ParseStatement(s.value(), &out); + ASSERT_TRUE(ss.ok()) << ss; +} + +} // namespace rewriter +} // namespace hybridse + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index e58eb8cd2cc..068538ba5e0 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -50,6 +50,7 @@ #include "plan/plan_api.h" #include "proto/fe_common.pb.h" #include "proto/tablet.pb.h" +#include "rewriter/ast_rewriter.h" #include "rpc/rpc_client.h" #include "schema/schema_adapter.h" #include "sdk/base.h" @@ -2676,14 +2677,34 @@ std::shared_ptr SQLClusterRouter::ExecuteSQL(const std } std::shared_ptr SQLClusterRouter::ExecuteSQL( - const std::string& db, const std::string& sql, std::shared_ptr parameter, + const std::string& db, const std::string& str, std::shared_ptr parameter, bool 
is_online_mode, bool is_sync_job, int offline_job_timeout, hybridse::sdk::Status* status) { RET_IF_NULL_AND_WARN(status, "output status is nullptr"); // functions we called later may not change the status if it's succeed. So if we pass error status here, we'll get a // fake error status->SetOK(); + + std::string sql = str; hybridse::vm::SqlContext ctx; + if (ANSISQLRewriterEnabled()) { + // If true, enable the ANSI SQL rewriter that would rewrite some SQL query + // for pre-defined pattern to OpenMLDB SQL extensions. Rewrite phase is before general SQL compilation. + // + // OpenMLDB SQL extensions, such as request mode query or LAST JOIN, would be helpful + // to simplify those that comes from like SparkSQL, and reserve the same semantics meaning. + // + // Rewrite rules are based on ASTNode, possibly lack some semantic checks. Turn it off if things + // go abnormal during rewrite phase. + auto s = hybridse::rewriter::Rewrite(sql); + if (s.ok()) { + LOG(INFO) << "rewrited: " << s.value(); + sql = s.value(); + } else { + LOG(WARNING) << s.status(); + } + } ctx.sql = sql; + auto sql_status = hybridse::plan::PlanAPI::CreatePlanTreeFromScript(&ctx); if (!sql_status.isOK()) { COPY_PREPEND_AND_WARN(status, sql_status, "create logic plan tree failed"); @@ -3196,6 +3217,18 @@ bool SQLClusterRouter::IsSyncJob() { return false; } +bool SQLClusterRouter::ANSISQLRewriterEnabled() { + // TODO(xxx): mark fn const + + std::lock_guard<::openmldb::base::SpinMutex> lock(mu_); + auto it = session_variables_.find("ansi_sql_rewriter"); + if (it != session_variables_.end() && it->second == "false") { + return false; + } + // TODO(xxx): always disable by default + return true; +} + int SQLClusterRouter::GetJobTimeout() { std::lock_guard<::openmldb::base::SpinMutex> lock(mu_); auto it = session_variables_.find("job_timeout"); @@ -3267,6 +3300,10 @@ ::hybridse::sdk::Status SQLClusterRouter::SetVariable(hybridse::node::SetPlanNod return {StatusCode::kCmdError, "Fail to parse spark config, 
set like 'spark.executor.memory=2g;spark.executor.cores=2'"}; } + } else if (key == "ansi_sql_rewriter") { + if (value != "true" && value != "false") { + return {StatusCode::kCmdError, "the value of " + key + " must be true|false"}; + } } else { return {}; } diff --git a/src/sdk/sql_cluster_router.h b/src/sdk/sql_cluster_router.h index 3d13cafa240..e917c170a14 100644 --- a/src/sdk/sql_cluster_router.h +++ b/src/sdk/sql_cluster_router.h @@ -443,6 +443,8 @@ class SQLClusterRouter : public SQLRouter { const base::Slice& value, const std::vector>& tablets); + bool ANSISQLRewriterEnabled(); + private: std::shared_ptr options_; std::string db_; From e307fd939b5dc7346964fd10358f8b2b29417750 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 13:20:08 +0800 Subject: [PATCH 14/23] build(deps-dev): bump urllib3 from 1.26.18 to 1.26.19 in /docs (#3948) Bumps [urllib3](https://github.com/urllib3/urllib3) from 1.26.18 to 1.26.19. - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/1.26.19/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/1.26.18...1.26.19) --- updated-dependencies: - dependency-name: urllib3 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/poetry.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/poetry.lock b/docs/poetry.lock index ca6c9ccb8c6..39577275304 100644 --- a/docs/poetry.lock +++ b/docs/poetry.lock @@ -670,13 +670,13 @@ test = ["coverage", "pytest", "pytest-cov"] [[package]] name = "urllib3" -version = "1.26.18" +version = "1.26.19" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ - {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, - {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, + {file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"}, + {file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"}, ] [package.extras] From 818d292cc5190c65ef3ead678f6742e828d0f163 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 26 Jun 2024 12:34:11 +0800 Subject: [PATCH 15/23] feat(udf): isin (#3939) --- cases/query/udf_query.yaml | 19 ++++++++++ hybridse/src/udf/default_defs/array_def.cc | 41 ++++++++++++++++++++-- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/cases/query/udf_query.yaml b/cases/query/udf_query.yaml index fefe1380dbb..ee9cad2d667 100644 --- a/cases/query/udf_query.yaml +++ b/cases/query/udf_query.yaml @@ -536,6 +536,25 @@ cases: data: | true, true, false, false, true, false, true, false, true, false, true + - id: isin + mode: request-unsupport + inputs: + - name: t1 + columns: ["col1:int32", "std_ts:timestamp", "col2:string"] + indexs: ["index1:col1:std_ts"] + rows: + - [1, 1590115420001, "ABCabcabc"] + sql: | + select + isin(2, [2,2]) as c0, + isin(cast(3 as int64), ARRAY[NULL, 1, 2]) as c1 + expect: + columns: + - c0 bool + - c1 bool + data: | + true, false + - id: array_split mode: request-unsupport inputs: diff --git a/hybridse/src/udf/default_defs/array_def.cc b/hybridse/src/udf/default_defs/array_def.cc index a1f35f38e35..b5c36bc3d7e 100644 --- a/hybridse/src/udf/default_defs/array_def.cc +++ b/hybridse/src/udf/default_defs/array_def.cc @@ -37,8 +37,30 @@ 
struct ArrayContains { // - bool/intxx/float/double -> bool/intxx/float/double // - Timestamp/Date/StringRef -> Timestamp*/Date*/StringRef* bool operator()(ArrayRef* arr, ParamType v, bool is_null) { - // NOTE: array_contains([null], null) returns null - // this might not expected + for (uint64_t i = 0; i < arr->size; ++i) { + if constexpr (std::is_pointer_v) { + // null or same value returns true + if ((is_null && arr->nullables[i]) || (!arr->nullables[i] && *arr->raw[i] == *v)) { + return true; + } + } else { + if ((is_null && arr->nullables[i]) || (!arr->nullables[i] && arr->raw[i] == v)) { + return true; + } + } + } + return false; + } +}; + +template +struct IsIn { + // udf registry types + using Args = std::tuple, ArrayRef>; + + using ParamType = typename DataTypeTrait::CCallArgType; + + bool operator()(ParamType v, bool is_null, ArrayRef* arr) { for (uint64_t i = 0; i < arr->size; ++i) { if constexpr (std::is_pointer_v) { // null or same value returns true @@ -98,6 +120,21 @@ void DefaultUdfLibrary::InitArrayUdfs() { @since 0.7.0 )"); + RegisterExternalTemplate("isin") + .args_in() + .doc(R"( + @brief isin(value, array) - Returns true if the array contains the value. 
+ + Example: + + @code{.sql} + select isin(2, [2,2]) as c0; + -- output true + @endcode + + @since 0.9.1 + )"); + RegisterExternal("split_array") .returns>() .return_by_arg(true) From 6b06e38a7c8e4f794503afb8b8ccfab8a1b96297 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 26 Jun 2024 12:52:41 +0800 Subject: [PATCH 16/23] feat(#3916): support @@execute_mode = 'request' (#3924) --- hybridse/include/vm/engine.h | 4 ++ hybridse/src/vm/engine.cc | 45 +++++++++++++++++------ hybridse/src/vm/engine_context.cc | 3 +- src/client/tablet_client.cc | 3 +- src/client/tablet_client.h | 2 +- src/sdk/internal/system_variable.cc | 57 +++++++++++++++++++++++++++++ src/sdk/internal/system_variable.h | 40 ++++++++++++++++++++ src/sdk/sql_cluster_router.cc | 37 ++++++++++++++----- src/sdk/sql_cluster_router.h | 5 ++- src/sdk/sql_router.h | 2 +- src/sdk/sql_router_sdk.i | 1 + src/tablet/tablet_impl.cc | 8 +++- 12 files changed, 178 insertions(+), 29 deletions(-) create mode 100644 src/sdk/internal/system_variable.cc create mode 100644 src/sdk/internal/system_variable.h diff --git a/hybridse/include/vm/engine.h b/hybridse/include/vm/engine.h index 09586a8b03d..7e183d43c33 100644 --- a/hybridse/include/vm/engine.h +++ b/hybridse/include/vm/engine.h @@ -429,6 +429,10 @@ class Engine { /// request row info exists in 'values' option, as a format of: /// 1. [(col1_expr, col2_expr, ... ), (...), ...] /// 2. (col1_expr, col2_expr, ... ) + // + // This function only check on request/batchrequest mode, for batch mode it does nothing. + // As for old-fashioned usage, request row does not need to appear in SQL, so it won't report + // error even request rows is empty, instead checks should performed at the very beginning of Compute. 
static absl::Status ExtractRequestRowsInSQL(SqlContext* ctx); std::shared_ptr GetCacheLocked(const std::string& db, diff --git a/hybridse/src/vm/engine.cc b/hybridse/src/vm/engine.cc index 5c179fb79b6..a2aaa30f305 100644 --- a/hybridse/src/vm/engine.cc +++ b/hybridse/src/vm/engine.cc @@ -462,17 +462,27 @@ int32_t RequestRunSession::Run(const uint32_t task_id, const Row& in_row, Row* o if (!sql_request_rows.empty()) { row = sql_request_rows.at(0); } - auto task = std::dynamic_pointer_cast(compile_info_) - ->get_sql_context() - .cluster_job->GetTask(task_id) - .GetRoot(); + + auto info = std::dynamic_pointer_cast(compile_info_); + auto main_task_id = info->GetClusterJob()->main_task_id(); + if (task_id == main_task_id && !info->GetRequestSchema().empty() && row.empty()) { + // a non-empty request row required but it not. + // checks only happen for a top level query, not subquery, + // since query internally may construct a empty row as input row, + // not meaning row with no columns, but row with all column values NULL + LOG(WARNING) << "request SQL requires a non-empty request row, but empty row received"; + // TODO(someone): use status + return common::StatusCode::kRunSessionError; + } + + auto task = info->get_sql_context().cluster_job->GetTask(task_id).GetRoot(); + if (nullptr == task) { LOG(WARNING) << "fail to run request plan: taskid" << task_id << " not exist!"; return -2; } DLOG(INFO) << "Request Row Run with task_id " << task_id; - RunnerContext ctx(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job, row, - sp_name_, is_debug_); + RunnerContext ctx(info->get_sql_context().cluster_job, row, sp_name_, is_debug_); auto output = task->RunWithCache(ctx); if (!output) { LOG(WARNING) << "Run request plan output is null"; @@ -491,13 +501,24 @@ int32_t BatchRequestRunSession::Run(const std::vector& request_batch, std:: } int32_t BatchRequestRunSession::Run(const uint32_t id, const std::vector& request_batch, std::vector& output) { - 
std::vector<::hybridse::codec::Row>& sql_request_rows = - std::dynamic_pointer_cast(GetCompileInfo())->get_sql_context().request_rows; + auto info = std::dynamic_pointer_cast(GetCompileInfo()); + std::vector<::hybridse::codec::Row>& sql_request_rows = info->get_sql_context().request_rows; + + std::vector<::hybridse::codec::Row> rows = sql_request_rows; + if (rows.empty()) { + rows = request_batch; + } + + auto main_task_id = info->GetClusterJob()->main_task_id(); + if (id != main_task_id && !info->GetRequestSchema().empty() && rows.empty()) { + // a non-empty request row list required but it not + LOG(WARNING) << "batchrequest SQL requires a non-empty request row list, but empty row list received"; + // TODO(someone): use status + return common::StatusCode::kRunSessionError; + } - RunnerContext ctx(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job, - sql_request_rows.empty() ? request_batch : sql_request_rows, sp_name_, is_debug_); - auto task = - std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job->GetTask(id).GetRoot(); + RunnerContext ctx(info->get_sql_context().cluster_job, rows, sp_name_, is_debug_); + auto task = info->get_sql_context().cluster_job->GetTask(id).GetRoot(); if (nullptr == task) { LOG(WARNING) << "Fail to run request plan: taskid" << id << " not exist!"; return -2; diff --git a/hybridse/src/vm/engine_context.cc b/hybridse/src/vm/engine_context.cc index 570726aa0eb..47ff39437d6 100644 --- a/hybridse/src/vm/engine_context.cc +++ b/hybridse/src/vm/engine_context.cc @@ -17,6 +17,7 @@ #include "vm/engine_context.h" #include "absl/container/flat_hash_map.h" +#include "absl/strings/ascii.h" namespace hybridse { namespace vm { @@ -61,7 +62,7 @@ std::string EngineModeName(EngineMode mode) { absl::StatusOr UnparseEngineMode(absl::string_view str) { auto& m = getModeMap(); - auto it = m.find(str); + auto it = m.find(absl::AsciiStrToLower(str)); if (it != m.end()) { return it->second; } diff --git 
a/src/client/tablet_client.cc b/src/client/tablet_client.cc index a1dd925fcde..cbfb794817f 100644 --- a/src/client/tablet_client.cc +++ b/src/client/tablet_client.cc @@ -73,6 +73,7 @@ bool TabletClient::Query(const std::string& db, const std::string& sql, const st } bool TabletClient::Query(const std::string& db, const std::string& sql, + hybridse::vm::EngineMode default_mode, const std::vector& parameter_types, const std::string& parameter_row, brpc::Controller* cntl, ::openmldb::api::QueryResponse* response, const bool is_debug) { @@ -80,7 +81,7 @@ bool TabletClient::Query(const std::string& db, const std::string& sql, ::openmldb::api::QueryRequest request; request.set_sql(sql); request.set_db(db); - request.set_is_batch(true); + request.set_is_batch(default_mode == hybridse::vm::kBatchMode); request.set_is_debug(is_debug); request.set_parameter_row_size(parameter_row.size()); request.set_parameter_row_slices(1); diff --git a/src/client/tablet_client.h b/src/client/tablet_client.h index 177124208fc..33188adadcc 100644 --- a/src/client/tablet_client.h +++ b/src/client/tablet_client.h @@ -64,7 +64,7 @@ class TabletClient : public Client { const openmldb::common::VersionPair& pair, std::string& msg); // NOLINT - bool Query(const std::string& db, const std::string& sql, + bool Query(const std::string& db, const std::string& sql, hybridse::vm::EngineMode default_mode, const std::vector& parameter_types, const std::string& parameter_row, brpc::Controller* cntl, ::openmldb::api::QueryResponse* response, const bool is_debug = false); diff --git a/src/sdk/internal/system_variable.cc b/src/sdk/internal/system_variable.cc new file mode 100644 index 00000000000..dc818dfb412 --- /dev/null +++ b/src/sdk/internal/system_variable.cc @@ -0,0 +1,57 @@ +/** + * Copyright (c) 2024 OpenMLDB Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sdk/internal/system_variable.h" + +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" + +namespace openmldb { +namespace sdk { +namespace internal { + +static SystemVariables CreateSystemVariablePresets() { + SystemVariables map = { + {"execute_mode", {"online", "request", "offline"}}, + // TODO(someone): add all + }; + return map; +} + +const SystemVariables& GetSystemVariablePresets() { + static const SystemVariables& map = *new auto(CreateSystemVariablePresets()); + return map; +} +absl::Status CheckSystemVariableSet(absl::string_view key, absl::string_view val) { + auto& presets = GetSystemVariablePresets(); + auto it = presets.find(absl::AsciiStrToLower(key)); + if (it == presets.end()) { + return absl::InvalidArgumentError(absl::Substitute("key '$0' not found as a system variable", key)); + } + + if (it->second.find(absl::AsciiStrToLower(val)) == it->second.end()) { + return absl::InvalidArgumentError( + absl::Substitute("invalid value for system variable '$0', expect one of [$1], but got $2", key, + absl::StrJoin(it->second, ", "), val)); + } + + return absl::OkStatus(); +} +} // namespace internal +} // namespace sdk +} // namespace openmldb diff --git a/src/sdk/internal/system_variable.h b/src/sdk/internal/system_variable.h new file mode 100644 index 00000000000..75edd014ab7 --- /dev/null +++ b/src/sdk/internal/system_variable.h @@ -0,0 +1,40 @@ +/** + * Copyright (c) 2024 OpenMLDB Authors + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_SDK_INTERNAL_SYSTEM_VARIABLE_H_ +#define SRC_SDK_INTERNAL_SYSTEM_VARIABLE_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" + +namespace openmldb { +namespace sdk { +namespace internal { + +using SystemVariables = absl::flat_hash_map>; + +const SystemVariables& GetSystemVariablePresets(); + +// check if the stmt 'set {key} = {val}' has a valid semantic +// key and value for system variable is case insensetive +absl::Status CheckSystemVariableSet(absl::string_view key, absl::string_view val); + +} // namespace internal +} // namespace sdk +} // namespace openmldb + +#endif // SRC_SDK_INTERNAL_SYSTEM_VARIABLE_H_ diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index 068538ba5e0..dbdd7dede9d 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -56,6 +56,7 @@ #include "sdk/base.h" #include "sdk/base_impl.h" #include "sdk/batch_request_result_set_sql.h" +#include "sdk/internal/system_variable.h" #include "sdk/job_table_helper.h" #include "sdk/node_adapter.h" #include "sdk/query_future_impl.h" @@ -1215,8 +1216,8 @@ std::shared_ptr<::hybridse::sdk::ResultSet> SQLClusterRouter::ExecuteSQLParamete cntl->set_timeout_ms(options_->request_timeout); DLOG(INFO) << "send query to tablet " << client->GetEndpoint(); auto response = std::make_shared<::openmldb::api::QueryResponse>(); - if (!client->Query(db, sql, 
parameter_types, parameter ? parameter->GetRow() : "", cntl.get(), response.get(), - options_->enable_debug)) { + if (!client->Query(db, sql, GetDefaultEngineMode(), parameter_types, parameter ? parameter->GetRow() : "", + cntl.get(), response.get(), options_->enable_debug)) { // rpc error is in cntl or response RPC_STATUS_AND_WARN(status, cntl, response, "Query rpc failed"); return {}; @@ -1995,7 +1996,7 @@ std::shared_ptr SQLClusterRouter::HandleSQLCmd(const h case hybridse::node::kCmdShowGlobalVariables: { std::string db = openmldb::nameserver::INFORMATION_SCHEMA_DB; std::string table = openmldb::nameserver::GLOBAL_VARIABLES; - std::string sql = "select * from " + table; + std::string sql = "select * from " + table + " CONFIG (execute_mode = 'online')"; ::hybridse::sdk::Status status; auto rs = ExecuteSQLParameterized(db, sql, std::shared_ptr(), &status); if (status.code != 0) { @@ -2884,9 +2885,7 @@ std::shared_ptr SQLClusterRouter::ExecuteSQL( } case hybridse::node::kPlanTypeFuncDef: case hybridse::node::kPlanTypeQuery: { - ::hybridse::vm::EngineMode default_mode = (!cluster_sdk_->IsClusterMode() || is_online_mode) - ? ::hybridse::vm::EngineMode::kBatchMode - : ::hybridse::vm::EngineMode::kOffline; + ::hybridse::vm::EngineMode default_mode = GetDefaultEngineMode(); // execute_mode in query config clause takes precedence auto mode = ::hybridse::vm::Engine::TryDetermineEngineMode(sql, default_mode); if (mode != ::hybridse::vm::EngineMode::kOffline) { @@ -3191,10 +3190,27 @@ std::shared_ptr SQLClusterRouter::ExecuteOfflineQuery( } } -bool SQLClusterRouter::IsOnlineMode() { +::hybridse::vm::EngineMode SQLClusterRouter::GetDefaultEngineMode() const { std::lock_guard<::openmldb::base::SpinMutex> lock(mu_); auto it = session_variables_.find("execute_mode"); - if (it != session_variables_.end() && it->second == "online") { + if (it != session_variables_.end()) { + // 1. 
infer from system variable + auto m = hybridse::vm::UnparseEngineMode(it->second).value_or(hybridse::vm::EngineMode::kBatchMode); + + // 2. standalone mode do not have offline + if (!cluster_sdk_->IsClusterMode() && m == hybridse::vm::kOffline) { + return hybridse::vm::kBatchMode; + } + return m; + } + + return hybridse::vm::EngineMode::kBatchMode; +} + +bool SQLClusterRouter::IsOnlineMode() const { + std::lock_guard<::openmldb::base::SpinMutex> lock(mu_); + auto it = session_variables_.find("execute_mode"); + if (it != session_variables_.end() && (it->second == "online" || it->second == "request")) { return true; } return false; @@ -3273,8 +3289,9 @@ ::hybridse::sdk::Status SQLClusterRouter::SetVariable(hybridse::node::SetPlanNod std::transform(value.begin(), value.end(), value.begin(), ::tolower); // TODO(hw): validation can be simpler if (key == "execute_mode") { - if (value != "online" && value != "offline") { - return {StatusCode::kCmdError, "the value of execute_mode must be online|offline"}; + auto s = sdk::internal::CheckSystemVariableSet(key, value); + if (!s.ok()) { + return {hybridse::common::kCmdError, s.ToString()}; } } else if (key == "enable_trace" || key == "sync_job") { if (value != "true" && value != "false") { diff --git a/src/sdk/sql_cluster_router.h b/src/sdk/sql_cluster_router.h index e917c170a14..e3159c0d1d7 100644 --- a/src/sdk/sql_cluster_router.h +++ b/src/sdk/sql_cluster_router.h @@ -274,7 +274,8 @@ class SQLClusterRouter : public SQLRouter { bool NotifyTableChange() override; - bool IsOnlineMode() override; + ::hybridse::vm::EngineMode GetDefaultEngineMode() const; + bool IsOnlineMode() const override; bool IsEnableTrace(); std::string GetDatabase() override; @@ -454,7 +455,7 @@ class SQLClusterRouter : public SQLRouter { DBSDK* cluster_sdk_; std::map>>> input_lru_cache_; - ::openmldb::base::SpinMutex mu_; + mutable ::openmldb::base::SpinMutex mu_; ::openmldb::base::Random rand_; std::atomic insert_memory_usage_limit_ = 0; // [0-100], 
the default value 0 means unlimited }; diff --git a/src/sdk/sql_router.h b/src/sdk/sql_router.h index f68d7d39a1c..55fd72b6b5e 100644 --- a/src/sdk/sql_router.h +++ b/src/sdk/sql_router.h @@ -220,7 +220,7 @@ class SQLRouter { virtual bool NotifyTableChange() = 0; - virtual bool IsOnlineMode() = 0; + virtual bool IsOnlineMode() const = 0; virtual std::string GetDatabase() = 0; diff --git a/src/sdk/sql_router_sdk.i b/src/sdk/sql_router_sdk.i index 15ea2b8e7c4..0186550e0a9 100644 --- a/src/sdk/sql_router_sdk.i +++ b/src/sdk/sql_router_sdk.i @@ -80,6 +80,7 @@ #include "sdk/sql_insert_row.h" #include "sdk/sql_delete_row.h" #include "sdk/table_reader.h" +#include "sdk/internal/system_variable.h" using hybridse::sdk::Schema; using hybridse::sdk::ColumnTypes; diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index 2f7544f2847..1545b96c9d4 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -1691,6 +1691,7 @@ void TabletImpl::ProcessQuery(bool is_sub, RpcController* ctrl, const openmldb:: auto mode = hybridse::vm::Engine::TryDetermineEngineMode(request->sql(), default_mode); ::hybridse::base::Status status; + // FIXME(someone): it does not handles batchrequest if (mode == hybridse::vm::EngineMode::kBatchMode) { // convert repeated openmldb:type::DataType into hybridse::codec::Schema hybridse::codec::Schema parameter_schema; @@ -5426,7 +5427,12 @@ void TabletImpl::RunRequestQuery(RpcController* ctrl, const openmldb::api::Query } if (ret != 0) { response.set_code(::openmldb::base::kSQLRunError); - response.set_msg("fail to run sql"); + if (ret == hybridse::common::StatusCode::kRunSessionError) { + // special handling + response.set_msg("request SQL requires a non-empty request row, but empty row received"); + } else { + response.set_msg("fail to run sql"); + } return; } else if (row.GetRowPtrCnt() != 1) { response.set_code(::openmldb::base::kSQLRunError); From cf86f04f6f64c98d6843179dcd872eeed6ddb251 Mon Sep 17 00:00:00 2001 From: 
aceforeverd Date: Wed, 26 Jun 2024 12:52:57 +0800 Subject: [PATCH 17/23] feat(udf): array_combine & array_join (#3945) * feat(udf): array_combine * feat(udf): new functions - array_combine - array_join * feat: casting arrays to array for array_combine WIP, string allocation need fix * fix: array_combine with non-string types * feat(array_combine): handle null inputs * fix(array_combine): behavior tweaks - use empty string if delimiter is null - restrict to array_combine(string, array ...) --- cases/query/udf_query.yaml | 92 +++++++++++++++++ hybridse/src/base/cartesian_product.cc | 62 +++++++++++ hybridse/src/base/cartesian_product.h | 32 ++++++ hybridse/src/codegen/array_ir_builder.cc | 115 +++++++++++++++++++++ hybridse/src/codegen/array_ir_builder.h | 14 ++- hybridse/src/codegen/ir_base_builder.cc | 20 +++- hybridse/src/codegen/string_ir_builder.cc | 12 +++ hybridse/src/codegen/string_ir_builder.h | 4 + hybridse/src/codegen/struct_ir_builder.cc | 96 ++++++++++++++++- hybridse/src/codegen/struct_ir_builder.h | 13 +++ hybridse/src/udf/default_defs/array_def.cc | 78 ++++++++++++++ hybridse/src/udf/udf.cc | 55 ++++++++++ hybridse/src/udf/udf.h | 3 +- hybridse/src/vm/jit_wrapper.cc | 7 ++ 14 files changed, 595 insertions(+), 8 deletions(-) create mode 100644 hybridse/src/base/cartesian_product.cc create mode 100644 hybridse/src/base/cartesian_product.h diff --git a/cases/query/udf_query.yaml b/cases/query/udf_query.yaml index ee9cad2d667..bc0cbe786fd 100644 --- a/cases/query/udf_query.yaml +++ b/cases/query/udf_query.yaml @@ -573,6 +573,98 @@ cases: - c1 bool data: | true, false + - id: array_join + mode: request-unsupport + sql: | + select + array_join(["1", "2"], ",") c1, + array_join(["1", "2"], "") c2, + array_join(["1", "2"], cast(null as string)) c3, + array_join(["1", NULL, "4", "5", NULL], "-") c4, + array_join(array[], ",") as c5 + expect: + columns: + - c1 string + - c2 string + - c3 string + - c4 string + - c5 string + rows: + - ["1,2", "12", "12", 
"1-4-5", ""] + - id: array_combine + mode: request-unsupport + sql: | + select + array_join(array_combine("-", ["1", "2"], ["3", "4"]), ",") c0, + expect: + columns: + - c0 string + rows: + - ["1-3,1-4,2-3,2-4"] + + - id: array_combine_2 + desc: array_combine casting array to array first + mode: request-unsupport + sql: | + select + array_join(array_combine("-", [1, 2], [3, 4]), ",") c0, + array_join(array_combine("-", [1, 2], array[3], ["5", "6"]), ",") c1, + array_join(array_combine("|", ["1"], [timestamp(1717171200000), timestamp("2024-06-02 12:00:00")]), ",") c2, + array_join(array_combine("|", ["1"]), ",") c3, + expect: + columns: + - c0 string + - c1 string + - c2 string + - c3 string + rows: + - ["1-3,1-4,2-3,2-4", "1-3-5,1-3-6,2-3-5,2-3-6", "1|2024-06-01 00:00:00,1|2024-06-02 12:00:00", "1"] + - id: array_combine_3 + desc: null values skipped + mode: request-unsupport + sql: | + select + array_join(array_combine("-", [1, NULL], [3, 4]), ",") c0, + array_join(array_combine("-", ARRAY[NULL], ["9", "8"]), ",") c1, + array_join(array_combine(string(NULL), ARRAY[1], ["9", "8"]), ",") c2, + expect: + columns: + - c0 string + - c1 string + - c2 string + rows: + - ["1-3,1-4", "", "19,18"] + - id: array_combine_4 + desc: construct array from table + mode: request-unsupport + inputs: + - name: t1 + columns: ["col1:int32", "std_ts:timestamp", "col2:string"] + indexs: ["index1:col1:std_ts"] + rows: + - [1, 1590115420001, "foo"] + - [2, 1590115420001, "bar"] + sql: | + select + col1, + array_join(array_combine("-", [col1, 10], [col2, "c2"]), ",") c0, + from t1 + expect: + columns: + - col1 int32 + - c0 string + rows: + - [1, "1-foo,1-c2,10-foo,10-c2"] + - [2, "2-bar,2-c2,10-bar,10-c2"] + - id: array_combine_err1 + mode: request-unsupport + sql: | + select + array_join(array_combine("-"), ",") c0, + expect: + success: false + msg: | + Fail to resolve expression: array_join(array_combine(-), ,) # ================================================================ # Map data 
type diff --git a/hybridse/src/base/cartesian_product.cc b/hybridse/src/base/cartesian_product.cc new file mode 100644 index 00000000000..f06d28edcdf --- /dev/null +++ b/hybridse/src/base/cartesian_product.cc @@ -0,0 +1,62 @@ +/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "base/cartesian_product.h" + +#include + +#include "absl/types/span.h" + +namespace hybridse { +namespace base { + +static auto cartesian_product(const std::vector>& lists) { + std::vector> result; + if (std::find_if(std::begin(lists), std::end(lists), [](auto e) -> bool { return e.size() == 0; }) != + std::end(lists)) { + return result; + } + for (auto& e : lists[0]) { + result.push_back({e}); + } + for (size_t i = 1; i < lists.size(); ++i) { + std::vector> temp; + for (auto& e : result) { + for (auto f : lists[i]) { + auto e_tmp = e; + e_tmp.push_back(f); + temp.push_back(e_tmp); + } + } + result = temp; + } + return result; +} + +std::vector> cartesian_product(absl::Span vec) { + std::vector> input; + for (auto& v : vec) { + std::vector seq(v, 0); + for (int i = 0; i < v; ++i) { + seq[i] = i; + } + input.push_back(seq); + } + return cartesian_product(input); +} + +} // namespace base +} // namespace hybridse diff --git a/hybridse/src/base/cartesian_product.h b/hybridse/src/base/cartesian_product.h new file mode 100644 index 00000000000..54d4191aba1 --- /dev/null +++ b/hybridse/src/base/cartesian_product.h @@ -0,0 +1,32 @@ 
+/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HYBRIDSE_SRC_BASE_CARTESIAN_PRODUCT_H_ +#define HYBRIDSE_SRC_BASE_CARTESIAN_PRODUCT_H_ + +#include + +#include "absl/types/span.h" + +namespace hybridse { +namespace base { + +std::vector> cartesian_product(absl::Span vec); + +} // namespace base +} // namespace hybridse + +#endif // HYBRIDSE_SRC_BASE_CARTESIAN_PRODUCT_H_ diff --git a/hybridse/src/codegen/array_ir_builder.cc b/hybridse/src/codegen/array_ir_builder.cc index 545ccb7555b..8e4a6005800 100644 --- a/hybridse/src/codegen/array_ir_builder.cc +++ b/hybridse/src/codegen/array_ir_builder.cc @@ -18,8 +18,12 @@ #include +#include "absl/strings/substitute.h" +#include "base/fe_status.h" +#include "codegen/cast_expr_ir_builder.h" #include "codegen/context.h" #include "codegen/ir_base_builder.h" +#include "codegen/string_ir_builder.h" namespace hybridse { namespace codegen { @@ -122,5 +126,116 @@ bool ArrayIRBuilder::CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** ou return true; } +absl::StatusOr ArrayIRBuilder::ExtractElement(CodeGenContextBase* ctx, const NativeValue& arr, + const NativeValue& key) const { + return absl::UnimplementedError("array extract element"); +} + +absl::StatusOr ArrayIRBuilder::NumElements(CodeGenContextBase* ctx, llvm::Value* arr) const { + llvm::Value* out = nullptr; + if (!Load(ctx->GetCurrentBlock(), arr, SZ_IDX, &out)) { + return absl::InternalError("codegen: 
fail to extract array size"); + } + + return out; +} + +absl::StatusOr ArrayIRBuilder::CastToArrayString(CodeGenContextBase* ctx, llvm::Value* src) { + auto sb = StructTypeIRBuilder::CreateStructTypeIRBuilder(ctx->GetModule(), src->getType()); + CHECK_ABSL_STATUSOR(sb); + + ArrayIRBuilder* src_builder = dynamic_cast(sb.value().get()); + if (!src_builder) { + return absl::InvalidArgumentError("input value not a array"); + } + + llvm::Type* src_ele_type = src_builder->element_type_; + if (IsStringPtr(src_ele_type)) { + // already array + return src; + } + + auto fields = src_builder->Load(ctx, src); + CHECK_ABSL_STATUSOR(fields); + llvm::Value* src_raws = fields.value().at(RAW_IDX); + llvm::Value* src_nulls = fields.value().at(NULL_IDX); + llvm::Value* num_elements = fields.value().at(SZ_IDX); + + llvm::Value* casted = nullptr; + if (!CreateDefault(ctx->GetCurrentBlock(), &casted)) { + return absl::InternalError("codegen error: fail to construct default array"); + } + // initialize each element + CHECK_ABSL_STATUS(Initialize(ctx, casted, {num_elements})); + + auto builder = ctx->GetBuilder(); + auto dst_fields = Load(ctx, casted); + CHECK_ABSL_STATUSOR(fields); + auto* raw_array_ptr = dst_fields.value().at(RAW_IDX); + auto* nullables_ptr = dst_fields.value().at(NULL_IDX); + + llvm::Type* idx_type = builder->getInt64Ty(); + llvm::Value* idx = builder->CreateAlloca(idx_type); + builder->CreateStore(builder->getInt64(0), idx); + CHECK_STATUS_TO_ABSL(ctx->CreateWhile( + [&](llvm::Value** cond) -> base::Status { + *cond = builder->CreateICmpSLT(builder->CreateLoad(idx_type, idx), num_elements); + return {}; + }, + [&]() -> base::Status { + llvm::Value* idx_val = builder->CreateLoad(idx_type, idx); + codegen::CastExprIRBuilder cast_builder(ctx->GetCurrentBlock()); + + llvm::Value* src_ele_value = + builder->CreateLoad(src_ele_type, builder->CreateGEP(src_ele_type, src_raws, idx_val)); + llvm::Value* dst_ele = + builder->CreateLoad(element_type_, 
builder->CreateGEP(element_type_, raw_array_ptr, idx_val)); + + codegen::StringIRBuilder str_builder(ctx->GetModule()); + auto s = str_builder.CastFrom(ctx->GetCurrentBlock(), src_ele_value, dst_ele); + CHECK_TRUE(s.ok(), common::kCodegenError, s.ToString()); + + builder->CreateStore( + builder->CreateLoad(builder->getInt1Ty(), builder->CreateGEP(builder->getInt1Ty(), src_nulls, idx_val)), + builder->CreateGEP(builder->getInt1Ty(), nullables_ptr, idx_val)); + + builder->CreateStore(builder->CreateAdd(idx_val, builder->getInt64(1)), idx); + return {}; + })); + + CHECK_ABSL_STATUS(Set(ctx, casted, {raw_array_ptr, nullables_ptr, num_elements})); + return casted; +} + +absl::Status ArrayIRBuilder::Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca, + absl::Span args) const { + auto* builder = ctx->GetBuilder(); + StringIRBuilder str_builder(ctx->GetModule()); + auto ele_type = str_builder.GetType(); + if (!alloca->getType()->isPointerTy() || alloca->getType()->getPointerElementType() != struct_type_ || + ele_type->getPointerTo() != element_type_) { + return absl::UnimplementedError(absl::Substitute( + "not able to Initialize array except array, got type $0", GetLlvmObjectString(alloca->getType()))); + } + if (args.size() != 1) { + // require one argument that is array size + return absl::InvalidArgumentError("initialize array requries one argument which is array size"); + } + if (!args[0]->getType()->isIntegerTy()) { + return absl::InvalidArgumentError("array size argument should be integer"); + } + auto sz = args[0]; + if (sz->getType() != builder->getInt64Ty()) { + CastExprIRBuilder cast_builder(ctx->GetCurrentBlock()); + base::Status s; + cast_builder.SafeCastNumber(sz, builder->getInt64Ty(), &sz, s); + CHECK_STATUS_TO_ABSL(s); + } + auto fn = ctx->GetModule()->getOrInsertFunction("hybridse_alloc_array_string", builder->getVoidTy(), + struct_type_->getPointerTo(), builder->getInt64Ty()); + + builder->CreateCall(fn, {alloca, sz}); + return absl::OkStatus(); +} 
} // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/array_ir_builder.h b/hybridse/src/codegen/array_ir_builder.h index a4ab9dd0a6e..9ee522b509c 100644 --- a/hybridse/src/codegen/array_ir_builder.h +++ b/hybridse/src/codegen/array_ir_builder.h @@ -42,11 +42,21 @@ class ArrayIRBuilder : public StructTypeIRBuilder { CHECK_TRUE(false, common::kCodegenError, "casting to array un-implemented"); }; - private: - void InitStructType() override; + absl::StatusOr CastToArrayString(CodeGenContextBase* ctx, llvm::Value* src); + + absl::StatusOr ExtractElement(CodeGenContextBase* ctx, const NativeValue& arr, + const NativeValue& key) const override; + + absl::StatusOr NumElements(CodeGenContextBase* ctx, llvm::Value* arr) const override; bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override; + absl::Status Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca, + absl::Span args) const override; + + private: + void InitStructType() override; + private: ::llvm::Type* element_type_ = nullptr; }; diff --git a/hybridse/src/codegen/ir_base_builder.cc b/hybridse/src/codegen/ir_base_builder.cc index 72dc7c93de7..f272cf9bda0 100644 --- a/hybridse/src/codegen/ir_base_builder.cc +++ b/hybridse/src/codegen/ir_base_builder.cc @@ -575,12 +575,12 @@ bool GetFullType(node::NodeManager* nm, ::llvm::Type* type, if (type_pointee->isStructTy()) { auto* key_type = type_pointee->getStructElementType(1); const node::TypeNode* key = nullptr; - if (key_type->isPointerTy() && !GetFullType(nm, key_type->getPointerElementType(), &key)) { + if (!key_type->isPointerTy() || !GetFullType(nm, key_type->getPointerElementType(), &key)) { return false; } const node::TypeNode* value = nullptr; auto* value_type = type_pointee->getStructElementType(2); - if (value_type->isPointerTy() && !GetFullType(nm, value_type->getPointerElementType(), &value)) { + if (!value_type->isPointerTy() || !GetFullType(nm, value_type->getPointerElementType(), &value)) { return 
false; } @@ -590,6 +590,22 @@ bool GetFullType(node::NodeManager* nm, ::llvm::Type* type, } return false; } + case hybridse::node::kArray: { + if (type->isPointerTy()) { + auto type_pointee = type->getPointerElementType(); + if (type_pointee->isStructTy()) { + auto* key_type = type_pointee->getStructElementType(0); + const node::TypeNode* key = nullptr; + if (!key_type->isPointerTy() || !GetFullType(nm, key_type->getPointerElementType(), &key)) { + return false; + } + + *type_node = nm->MakeNode(node::DataType::kArray, key); + return true; + } + } + return false; + } default: { *type_node = nm->MakeTypeNode(base); return true; diff --git a/hybridse/src/codegen/string_ir_builder.cc b/hybridse/src/codegen/string_ir_builder.cc index 083c907fbe4..7e677f3bd3c 100644 --- a/hybridse/src/codegen/string_ir_builder.cc +++ b/hybridse/src/codegen/string_ir_builder.cc @@ -403,5 +403,17 @@ base::Status StringIRBuilder::ConcatWS(::llvm::BasicBlock* block, *output = NativeValue::CreateWithFlag(concat_str, ret_null); return base::Status(); } +absl::Status StringIRBuilder::CastFrom(llvm::BasicBlock* block, llvm::Value* src, llvm::Value* alloca) { + if (IsStringPtr(src->getType())) { + return absl::UnimplementedError("not necessary to cast string to string"); + } + ::llvm::IRBuilder<> builder(block); + ::std::string fn_name = "string." 
+ TypeName(src->getType()); + + auto cast_func = m_->getOrInsertFunction( + fn_name, ::llvm::FunctionType::get(builder.getVoidTy(), {src->getType(), alloca->getType()}, false)); + builder.CreateCall(cast_func, {src, alloca}); + return absl::OkStatus(); +} } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/string_ir_builder.h b/hybridse/src/codegen/string_ir_builder.h index 84f73d2822d..3c6bf986a10 100644 --- a/hybridse/src/codegen/string_ir_builder.h +++ b/hybridse/src/codegen/string_ir_builder.h @@ -37,6 +37,10 @@ class StringIRBuilder : public StructTypeIRBuilder { base::Status CastFrom(::llvm::BasicBlock* block, const NativeValue& src, NativeValue* output) override; base::Status CastFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value** output); + // casting from {in} to string ptr alloca, alloca has allocated already. + // if {in} is string ptr already, it returns error status since generally it's not necessary to call this function + absl::Status CastFrom(llvm::BasicBlock* block, llvm::Value* in, llvm::Value* alloca); + bool NewString(::llvm::BasicBlock* block, ::llvm::Value** output); bool NewString(::llvm::BasicBlock* block, const std::string& str, ::llvm::Value** output); diff --git a/hybridse/src/codegen/struct_ir_builder.cc b/hybridse/src/codegen/struct_ir_builder.cc index 4b00b052c29..d288d05a57d 100644 --- a/hybridse/src/codegen/struct_ir_builder.cc +++ b/hybridse/src/codegen/struct_ir_builder.cc @@ -16,8 +16,11 @@ #include "codegen/struct_ir_builder.h" +#include + #include "absl/status/status.h" #include "absl/strings/substitute.h" +#include "codegen/array_ir_builder.h" #include "codegen/context.h" #include "codegen/date_ir_builder.h" #include "codegen/ir_base_builder.h" @@ -70,6 +73,16 @@ absl::StatusOr> StructTypeIRBuilder::Create } break; } + case node::DataType::kArray: { + assert(ctype->IsArray() && "logic error: not a array type"); + assert(ctype->GetGenericSize() == 1 && "logic error: not a array 
type"); + ::llvm::Type* ele_type = nullptr; + if (!codegen::GetLlvmType(m, ctype->GetGenericType(0), &ele_type)) { + return absl::InvalidArgumentError( + absl::Substitute("not able to casting array type: $0", GetLlvmObjectString(type))); + } + return std::make_unique(m, ele_type); + } default: { break; } @@ -181,6 +194,11 @@ absl::StatusOr StructTypeIRBuilder::ExtractElement(CodeGenContextBa absl::StrCat("extract element unimplemented for ", GetLlvmObjectString(struct_type_))); } +absl::StatusOr StructTypeIRBuilder::NumElements(CodeGenContextBase* ctx, llvm::Value* arr) const { + return absl::UnimplementedError( + absl::StrCat("element size unimplemented for ", GetLlvmObjectString(struct_type_))); +} + void StructTypeIRBuilder::EnsureOK() const { assert(struct_type_ != nullptr && "filed struct_type_ uninitialized"); // it's a identified type @@ -200,9 +218,9 @@ absl::Status StructTypeIRBuilder::Set(CodeGenContextBase* ctx, ::llvm::Value* st } if (struct_value->getType()->getPointerElementType() != struct_type_) { - return absl::InvalidArgumentError(absl::Substitute("input value has different type, expect $0 but got $1", - GetLlvmObjectString(struct_type_), - GetLlvmObjectString(struct_value->getType()))); + return absl::InvalidArgumentError( + absl::Substitute("input value has different type, expect $0 but got $1", GetLlvmObjectString(struct_type_), + GetLlvmObjectString(struct_value->getType()->getPointerElementType()))); } if (members.size() != struct_type_->getNumElements()) { @@ -229,6 +247,16 @@ absl::StatusOr> StructTypeIRBuilder::Load(CodeGenConte llvm::Value* struct_ptr) const { assert(ctx != nullptr && struct_ptr != nullptr); + if (!IsStructPtr(struct_ptr->getType())) { + return absl::InvalidArgumentError( + absl::StrCat("value not a struct pointer: ", GetLlvmObjectString(struct_ptr->getType()))); + } + if (struct_ptr->getType()->getPointerElementType() != struct_type_) { + return absl::InvalidArgumentError( + absl::Substitute("input value has different 
type, expect $0 but got $1", GetLlvmObjectString(struct_type_), + GetLlvmObjectString(struct_ptr->getType()->getPointerElementType()))); + } + std::vector res; res.reserve(struct_type_->getNumElements()); @@ -252,5 +280,67 @@ absl::StatusOr CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Ty return NativeValue(nullptr, nullptr, type); } +absl::StatusOr Combine(CodeGenContextBase* ctx, const NativeValue delimiter, + absl::Span args) { + auto builder = ctx->GetBuilder(); + + StringIRBuilder str_builder(ctx->GetModule()); + ArrayIRBuilder arr_builder(ctx->GetModule(), str_builder.GetType()->getPointerTo()); + + llvm::Value* empty_str = nullptr; + if (!str_builder.CreateDefault(ctx->GetCurrentBlock(), &empty_str)) { + return absl::InternalError("codegen error: fail to construct empty string"); + } + llvm::Value* del = builder->CreateSelect(delimiter.GetIsNull(builder), empty_str, delimiter.GetValue(builder)); + + llvm::Type* input_arr_type = arr_builder.GetType()->getPointerTo(); + llvm::Value* empty_arr = nullptr; + if (!arr_builder.CreateDefault(ctx->GetCurrentBlock(), &empty_arr)) { + return absl::InternalError("codegen error: fail to construct empty string of array"); + } + llvm::Value* input_arrays = builder->CreateAlloca(input_arr_type, builder->getInt32(args.size()), "array_data"); + node::NodeManager nm; + std::vector casted_args(args.size()); + for (int i = 0; i < args.size(); ++i) { + const node::TypeNode* tp = nullptr; + if (!GetFullType(&nm, args.at(i).GetType(), &tp)) { + return absl::InternalError("codegen error: fail to get valid type from llvm value"); + } + if (!tp->IsArray() || tp->GetGenericSize() != 1) { + return absl::InternalError("codegen error: arguments to array_combine is not ARRAY"); + } + if (!tp->GetGenericType(0)->IsString()) { + auto s = arr_builder.CastToArrayString(ctx, args.at(i).GetRaw()); + CHECK_ABSL_STATUSOR(s); + casted_args.at(i) = NativeValue::Create(s.value()); + } else { + casted_args.at(i) = args.at(i); + } + + auto 
safe_str_arr = + builder->CreateSelect(casted_args.at(i).GetIsNull(builder), empty_arr, casted_args.at(i).GetRaw()); + builder->CreateStore(safe_str_arr, builder->CreateGEP(input_arr_type, input_arrays, builder->getInt32(i))); + } + + ::llvm::FunctionCallee array_combine_fn = ctx->GetModule()->getOrInsertFunction( + "hybridse_array_combine", builder->getVoidTy(), str_builder.GetType()->getPointerTo(), builder->getInt32Ty(), + input_arr_type->getPointerTo(), input_arr_type); + assert(array_combine_fn); + + llvm::Value* out = builder->CreateAlloca(arr_builder.GetType()); + builder->CreateCall(array_combine_fn, { + del, // delimiter should ensure non-null + builder->getInt32(args.size()), // num of arrays + input_arrays, // ArrayRef** + out // output string + }); + + return NativeValue::Create(out); +} + +absl::Status StructTypeIRBuilder::Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca, + absl::Span args) const { + return absl::UnimplementedError(absl::StrCat("Initialize for type ", GetLlvmObjectString(struct_type_))); +} } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/struct_ir_builder.h b/hybridse/src/codegen/struct_ir_builder.h index 8529c2d7848..4c7dba732dc 100644 --- a/hybridse/src/codegen/struct_ir_builder.h +++ b/hybridse/src/codegen/struct_ir_builder.h @@ -58,6 +58,9 @@ class StructTypeIRBuilder : public TypeIRBuilder { virtual absl::StatusOr<::llvm::Value*> ConstructFromRaw(CodeGenContextBase* ctx, absl::Span<::llvm::Value* const> args) const; + virtual absl::Status Initialize(CodeGenContextBase* ctx, ::llvm::Value* alloca, + absl::Span args) const; + // Extract element value from composite data type // 1. extract from array type by index // 2. 
extract from struct type by field name @@ -65,6 +68,12 @@ class StructTypeIRBuilder : public TypeIRBuilder { virtual absl::StatusOr ExtractElement(CodeGenContextBase* ctx, const NativeValue& arr, const NativeValue& key) const; + // Get size of the elements inside value {arr} + // - if {arr} is array/map, return size of array/map + // - if {arr} is struct, return number of struct fields + // - otherwise report error + virtual absl::StatusOr NumElements(CodeGenContextBase* ctx, llvm::Value* arr) const; + ::llvm::Type* GetType() const; std::string GetTypeDebugString() const; @@ -100,6 +109,10 @@ class StructTypeIRBuilder : public TypeIRBuilder { // returns NativeValue{raw, is_null=true} on success, raw is ensured to be not nullptr absl::StatusOr CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Type* type); +// Do the cartesian product for a list of arrry +// output a array of string, each value is a pair (A1, B2, C3...), as "A1-B2-C3-...", "-" is the delimiter +absl::StatusOr Combine(CodeGenContextBase* ctx, const NativeValue delimiter, + absl::Span args); } // namespace codegen } // namespace hybridse #endif // HYBRIDSE_SRC_CODEGEN_STRUCT_IR_BUILDER_H_ diff --git a/hybridse/src/udf/default_defs/array_def.cc b/hybridse/src/udf/default_defs/array_def.cc index b5c36bc3d7e..5f1bebfaaf6 100644 --- a/hybridse/src/udf/default_defs/array_def.cc +++ b/hybridse/src/udf/default_defs/array_def.cc @@ -15,6 +15,7 @@ */ #include "absl/strings/str_split.h" +#include "codegen/struct_ir_builder.h" #include "udf/default_udf_library.h" #include "udf/udf.h" #include "udf/udf_registry.h" @@ -101,6 +102,36 @@ void SplitString(StringRef* str, StringRef* delimeter, ArrayRef* arra } } +void array_join(ArrayRef* arr, StringRef* del, bool del_null, StringRef* out) { + int sz = 0; + for (int i = 0; i < arr->size; ++i) { + if (!arr->nullables[i]) { + if (!del_null && i > 0) { + sz += del->size_; + } + sz += arr->raw[i]->size_; + } + } + + auto buf = udf::v1::AllocManagedStringBuf(sz); + 
memset(buf, 0, sz); + + int32_t idx = 0; + for (int i = 0; i < arr->size; ++i) { + if (!arr->nullables[i]) { + if (!del_null && i > 0) { + memcpy(buf + idx, del->data_, del->size_); + idx += del->size_; + } + memcpy(buf + idx, arr->raw[i]->data_, arr->raw[i]->size_); + idx += arr->raw[i]->size_; + } + } + + out->data_ = buf; + out->size_ = sz; +} + // =========================================================== // // UDF Register Entry // =========================================================== // @@ -148,6 +179,53 @@ void DefaultUdfLibrary::InitArrayUdfs() { @endcode @since 0.7.0)"); + RegisterExternal("array_join") + .args, Nullable>(array_join) + .doc(R"( + @brief array_join(array, delimiter) - Concatenates the elements of the given array using the delimiter. Any null value is filtered. + + Example: + + @code{.sql} + select array_join(["1", "2"], "-"); + -- output "1-2" + @endcode + @since 0.9.2)"); + + RegisterCodeGenUdf("array_combine") + .variadic_args>( + [](UdfResolveContext* ctx, const ExprAttrNode& delimit, const std::vector& arg_attrs, + ExprAttrNode* out) -> base::Status { + CHECK_TRUE(!arg_attrs.empty(), common::kCodegenError, "at least one array required by array_combine"); + for (auto & val : arg_attrs) { + CHECK_TRUE(val.type()->IsArray(), common::kCodegenError, "argument to array_combine must be array"); + } + auto nm = ctx->node_manager(); + out->SetType(nm->MakeNode(node::kArray, nm->MakeNode(node::kVarchar))); + out->SetNullable(false); + return {}; + }, + [](codegen::CodeGenContext* ctx, codegen::NativeValue del, const std::vector& args, + const node::ExprAttrNode& return_info, codegen::NativeValue* out) -> base::Status { + auto os = codegen::Combine(ctx, del, args); + CHECK_TRUE(os.ok(), common::kCodegenError, os.status().ToString()); + *out = os.value(); + return {}; + }) + .doc(R"( + @brief array_combine(delimiter, array1, array2, ...) + + return array of strings for input array1, array2, ... doing cartesian product. 
Each product is joined with + {delimiter} as a string. Empty string used if {delimiter} is null. + + Example: + + @code{.sql} + select array_combine("-", ["1", "2"], ["3", "4"]); + -- output ["1-3", "1-4", "2-3", "2-4"] + @endcode + @since 0.9.2 + )"); } } // namespace udf } // namespace hybridse diff --git a/hybridse/src/udf/udf.cc b/hybridse/src/udf/udf.cc index b32d75d4ac8..e5faf25db1e 100644 --- a/hybridse/src/udf/udf.cc +++ b/hybridse/src/udf/udf.cc @@ -21,11 +21,13 @@ #include #include +#include #include "absl/strings/ascii.h" #include "absl/strings/str_replace.h" #include "absl/time/civil_time.h" #include "absl/time/time.h" +#include "base/cartesian_product.h" #include "base/iterator.h" #include "boost/date_time/gregorian/conversion.hpp" #include "boost/date_time/gregorian/parsers.hpp" @@ -1413,6 +1415,59 @@ void printLog(const char* fmt) { } } +// each variadic arg is ArrayRef* +void array_combine(codec::StringRef *del, int32_t cnt, ArrayRef **data, + ArrayRef *out) { + std::vector arr_szs(cnt, 0); + for (int32_t i = 0; i < cnt; ++i) { + auto arr = data[i]; + arr_szs.at(i) = arr->size; + } + + // cal cartesian products + auto products = hybridse::base::cartesian_product(arr_szs); + + auto real_sz = products.size(); + v1::AllocManagedArray(out, products.size()); + + for (int prod_idx = 0; prod_idx < products.size(); ++prod_idx) { + auto &prod = products.at(prod_idx); + int32_t sz = 0; + for (int i = 0; i < prod.size(); ++i) { + if (!data[i]->nullables[prod.at(i)]) { + // delimiter would be empty string if null + if (i > 0) { + sz += del->size_; + } + sz += data[i]->raw[prod.at(i)]->size_; + } else { + // null exists in current product + // the only option now is to skip + real_sz--; + continue; + } + } + auto buf = v1::AllocManagedStringBuf(sz); + int32_t idx = 0; + for (int i = 0; i < prod.size(); ++i) { + if (!data[i]->nullables[prod.at(i)]) { + if (i > 0 && del->size_ > 0) { + memcpy(buf + idx, del->data_, del->size_); + idx += del->size_; + } + 
memcpy(buf + idx, data[i]->raw[prod.at(i)]->data_, data[i]->raw[prod.at(i)]->size_); + idx += data[i]->raw[prod.at(i)]->size_; + } + } + + out->nullables[prod_idx] = false; + out->raw[prod_idx]->data_ = buf; + out->raw[prod_idx]->size_ = sz; + } + + out->size = real_sz; +} + } // namespace v1 bool RegisterMethod(UdfLibrary *lib, const std::string &fn_name, hybridse::node::TypeNode *ret, diff --git a/hybridse/src/udf/udf.h b/hybridse/src/udf/udf.h index b7f222433a7..480e4f89f3c 100644 --- a/hybridse/src/udf/udf.h +++ b/hybridse/src/udf/udf.h @@ -520,7 +520,8 @@ void hex(StringRef *str, StringRef *output); void unhex(StringRef *str, StringRef *output, bool* is_null); void printLog(const char* fmt); - +void array_combine(codec::StringRef *del, int32_t cnt, ArrayRef **data, + ArrayRef *out); } // namespace v1 /// \brief register native udf related methods into given UdfLibrary `lib` diff --git a/hybridse/src/vm/jit_wrapper.cc b/hybridse/src/vm/jit_wrapper.cc index d4b4df1b01d..220a7d93085 100644 --- a/hybridse/src/vm/jit_wrapper.cc +++ b/hybridse/src/vm/jit_wrapper.cc @@ -17,6 +17,8 @@ #include #include + +#include "base/cartesian_product.h" #include "glog/logging.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" @@ -251,6 +253,11 @@ void InitBuiltinJitSymbols(HybridSeJitWrapper* jit) { "fmod", reinterpret_cast( static_cast(&fmod))); jit->AddExternalFunction("fmodf", reinterpret_cast(&fmodf)); + + // cartesian product + jit->AddExternalFunction("hybridse_array_combine", reinterpret_cast(&hybridse::udf::v1::array_combine)); + jit->AddExternalFunction("hybridse_alloc_array_string", + reinterpret_cast(&hybridse::udf::v1::AllocManagedArray)); } } // namespace vm From e3da2a6bbfa45792ae925e27667dd38af8cd3b99 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 26 Jun 2024 16:05:58 +0800 Subject: [PATCH 18/23] feat: support batchrequest in ProcessQuery (#3938) --- src/tablet/tablet_impl.cc | 239 
++++++++++++++++++++++++-------------- 1 file changed, 151 insertions(+), 88 deletions(-) diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index 1545b96c9d4..230b5c46a09 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -20,6 +20,7 @@ #include #include #include +#include "vm/sql_compiler.h" #ifdef DISALLOW_COPY_AND_ASSIGN #undef DISALLOW_COPY_AND_ASSIGN #endif @@ -1691,98 +1692,127 @@ void TabletImpl::ProcessQuery(bool is_sub, RpcController* ctrl, const openmldb:: auto mode = hybridse::vm::Engine::TryDetermineEngineMode(request->sql(), default_mode); ::hybridse::base::Status status; - // FIXME(someone): it does not handles batchrequest - if (mode == hybridse::vm::EngineMode::kBatchMode) { - // convert repeated openmldb:type::DataType into hybridse::codec::Schema - hybridse::codec::Schema parameter_schema; - for (int i = 0; i < request->parameter_types().size(); i++) { - auto column = parameter_schema.Add(); - hybridse::type::Type hybridse_type; - - if (!openmldb::schema::SchemaAdapter::ConvertType(request->parameter_types(i), &hybridse_type)) { - response->set_msg("Invalid parameter type: " + - openmldb::type::DataType_Name(request->parameter_types(i))); - response->set_code(::openmldb::base::kSQLCompileError); - return; + switch (mode) { + case hybridse::vm::EngineMode::kBatchMode: { + // convert repeated openmldb:type::DataType into hybridse::codec::Schema + hybridse::codec::Schema parameter_schema; + for (int i = 0; i < request->parameter_types().size(); i++) { + auto column = parameter_schema.Add(); + hybridse::type::Type hybridse_type; + + if (!openmldb::schema::SchemaAdapter::ConvertType(request->parameter_types(i), &hybridse_type)) { + response->set_msg("Invalid parameter type: " + + openmldb::type::DataType_Name(request->parameter_types(i))); + response->set_code(::openmldb::base::kSQLCompileError); + return; + } + column->set_type(hybridse_type); } - column->set_type(hybridse_type); - } - 
::hybridse::vm::BatchRunSession session; - if (request->is_debug()) { - session.EnableDebug(); - } - session.SetParameterSchema(parameter_schema); - { - bool ok = engine_->Get(request->sql(), request->db(), session, status); - if (!ok) { - response->set_msg(status.msg); - response->set_code(::openmldb::base::kSQLCompileError); - DLOG(WARNING) << "fail to compile sql " << request->sql() << ", message: " << status.msg; - return; + ::hybridse::vm::BatchRunSession session; + if (request->is_debug()) { + session.EnableDebug(); + } + session.SetParameterSchema(parameter_schema); + { + bool ok = engine_->Get(request->sql(), request->db(), session, status); + if (!ok) { + response->set_msg(status.msg); + response->set_code(::openmldb::base::kSQLCompileError); + DLOG(WARNING) << "fail to compile sql " << request->sql() << ", message: " << status.msg; + return; + } } - } - ::hybridse::codec::Row parameter_row; - auto& request_buf = static_cast(ctrl)->request_attachment(); - if (request->parameter_row_size() > 0 && - !codec::DecodeRpcRow(request_buf, 0, request->parameter_row_size(), request->parameter_row_slices(), - ¶meter_row)) { - response->set_code(::openmldb::base::kSQLRunError); - response->set_msg("fail to decode parameter row"); - return; - } - std::vector<::hybridse::codec::Row> output_rows; - int32_t run_ret = session.Run(parameter_row, output_rows); - if (run_ret != 0) { - response->set_msg(status.msg); - response->set_code(::openmldb::base::kSQLRunError); - DLOG(WARNING) << "fail to run sql: " << request->sql(); - return; - } - uint32_t byte_size = 0; - uint32_t count = 0; - for (auto& output_row : output_rows) { - if (FLAGS_scan_max_bytes_size > 0 && byte_size > FLAGS_scan_max_bytes_size) { - LOG(WARNING) << "reach the max byte size " << FLAGS_scan_max_bytes_size << " truncate result"; - response->set_schema(session.GetEncodedSchema()); - response->set_byte_size(byte_size); - response->set_count(count); - response->set_code(::openmldb::base::kOk); + 
::hybridse::codec::Row parameter_row; + auto& request_buf = static_cast(ctrl)->request_attachment(); + if (request->parameter_row_size() > 0 && + !codec::DecodeRpcRow(request_buf, 0, request->parameter_row_size(), request->parameter_row_slices(), + ¶meter_row)) { + response->set_code(::openmldb::base::kSQLRunError); + response->set_msg("fail to decode parameter row"); + return; + } + std::vector<::hybridse::codec::Row> output_rows; + int32_t run_ret = session.Run(parameter_row, output_rows); + if (run_ret != 0) { + response->set_msg(status.msg); + response->set_code(::openmldb::base::kSQLRunError); + DLOG(WARNING) << "fail to run sql: " << request->sql(); return; } - byte_size += output_row.size(); - buf->append(reinterpret_cast(output_row.buf()), output_row.size()); - count += 1; + uint32_t byte_size = 0; + uint32_t count = 0; + for (auto& output_row : output_rows) { + if (FLAGS_scan_max_bytes_size > 0 && byte_size > FLAGS_scan_max_bytes_size) { + LOG(WARNING) << "reach the max byte size " << FLAGS_scan_max_bytes_size << " truncate result"; + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + return; + } + byte_size += output_row.size(); + buf->append(reinterpret_cast(output_row.buf()), output_row.size()); + count += 1; + } + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + DLOG(INFO) << "handle batch sql " << request->sql() << " with record cnt " << count << " byte size " + << byte_size; + break; } - response->set_schema(session.GetEncodedSchema()); - response->set_byte_size(byte_size); - response->set_count(count); - response->set_code(::openmldb::base::kOk); - DLOG(INFO) << "handle batch sql " << request->sql() << " with record cnt " << count << " byte size " - << byte_size; - } else { - ::hybridse::vm::RequestRunSession session; - if 
(request->is_debug()) { - session.EnableDebug(); - } - if (request->is_procedure()) { - const std::string& db_name = request->db(); - const std::string& sp_name = request->sp_name(); - std::shared_ptr request_compile_info; - { - hybridse::base::Status status; - request_compile_info = sp_cache_->GetRequestInfo(db_name, sp_name, status); - if (!status.isOK()) { - response->set_code(::openmldb::base::ReturnCode::kProcedureNotFound); + case hybridse::vm::kRequestMode: { + ::hybridse::vm::RequestRunSession session; + if (request->is_debug()) { + session.EnableDebug(); + } + if (request->is_procedure()) { + const std::string& db_name = request->db(); + const std::string& sp_name = request->sp_name(); + std::shared_ptr request_compile_info; + { + hybridse::base::Status status; + request_compile_info = sp_cache_->GetRequestInfo(db_name, sp_name, status); + if (!status.isOK()) { + response->set_code(::openmldb::base::ReturnCode::kProcedureNotFound); + response->set_msg(status.msg); + PDLOG(WARNING, status.msg.c_str()); + return; + } + } + session.SetCompileInfo(request_compile_info); + session.SetSpName(sp_name); + RunRequestQuery(ctrl, *request, session, *response, *buf); + } else { + bool ok = engine_->Get(request->sql(), request->db(), session, status); + if (!ok || session.GetCompileInfo() == nullptr) { response->set_msg(status.msg); - PDLOG(WARNING, status.msg.c_str()); + response->set_code(::openmldb::base::kSQLCompileError); + DLOG(WARNING) << "fail to compile sql in request mode:\n" << request->sql(); return; } + RunRequestQuery(ctrl, *request, session, *response, *buf); + } + const std::string& sql = session.GetCompileInfo()->GetSql(); + if (response->code() != ::openmldb::base::kOk) { + DLOG(WARNING) << "fail to run sql " << sql << " error msg: " << response->msg(); + } else { + DLOG(INFO) << "handle request sql " << sql; + } + break; + } + case hybridse::vm::kBatchRequestMode: { + // we support a simplified batch request query here + // not procedure + // no 
parameter input or bachrequst row + // batchrequest row must specified in CONFIG (values = ...) + ::hybridse::base::Status status; + ::hybridse::vm::BatchRequestRunSession session; + if (request->is_debug()) { + session.EnableDebug(); } - session.SetCompileInfo(request_compile_info); - session.SetSpName(sp_name); - RunRequestQuery(ctrl, *request, session, *response, *buf); - } else { bool ok = engine_->Get(request->sql(), request->db(), session, status); if (!ok || session.GetCompileInfo() == nullptr) { response->set_msg(status.msg); @@ -1790,13 +1820,46 @@ void TabletImpl::ProcessQuery(bool is_sub, RpcController* ctrl, const openmldb:: DLOG(WARNING) << "fail to compile sql in request mode:\n" << request->sql(); return; } - RunRequestQuery(ctrl, *request, session, *response, *buf); + auto info = std::dynamic_pointer_cast(session.GetCompileInfo()); + if (info && info->get_sql_context().request_rows.empty()) { + response->set_msg("batch request values must specified in SQL CONFIG (values = [...])"); + response->set_code(::openmldb::base::kSQLCompileError); + return; + } + std::vector<::hybridse::codec::Row> output_rows; + std::vector<::hybridse::codec::Row> empty_inputs; + int32_t run_ret = session.Run(empty_inputs, output_rows); + if (run_ret != 0) { + response->set_msg(status.msg); + response->set_code(::openmldb::base::kSQLRunError); + DLOG(WARNING) << "fail to run batchrequest sql: " << request->sql(); + return; + } + uint32_t byte_size = 0; + uint32_t count = 0; + for (auto& output_row : output_rows) { + if (FLAGS_scan_max_bytes_size > 0 && byte_size > FLAGS_scan_max_bytes_size) { + LOG(WARNING) << "reach the max byte size " << FLAGS_scan_max_bytes_size << " truncate result"; + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + return; + } + byte_size += output_row.size(); + buf->append(reinterpret_cast(output_row.buf()), output_row.size()); + count 
+= 1; + } + response->set_schema(session.GetEncodedSchema()); + response->set_byte_size(byte_size); + response->set_count(count); + response->set_code(::openmldb::base::kOk); + break; } - const std::string& sql = session.GetCompileInfo()->GetSql(); - if (response->code() != ::openmldb::base::kOk) { - DLOG(WARNING) << "fail to run sql " << sql << " error msg: " << response->msg(); - } else { - DLOG(INFO) << "handle request sql " << sql; + default: { + response->set_msg("un-implemented execute_mode: " + hybridse::vm::EngineModeName(mode)); + response->set_code(::openmldb::base::kSQLCompileError); + break; } } } From 2a739528e2af6c2702590be219c25beefac10f9c Mon Sep 17 00:00:00 2001 From: oh2024 <162292688+oh2024@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:50:58 +0800 Subject: [PATCH 19/23] feat: user authz (#3941) * feat: change user table to match mysql * feat: support user authz * fix: cean up created users --- hybridse/include/node/node_enum.h | 4 + hybridse/include/node/plan_node.h | 61 ++++++++++++ hybridse/include/node/sql_node.h | 58 +++++++++++ hybridse/src/plan/planner.cc | 16 +++ hybridse/src/planv2/ast_node_converter.cc | 90 +++++++++++++++++ hybridse/src/planv2/ast_node_converter.h | 6 ++ src/auth/user_access_manager.cc | 41 ++++++-- src/auth/user_access_manager.h | 10 +- src/base/status.h | 3 +- src/client/ns_client.cc | 29 ++++++ src/client/ns_client.h | 5 + src/cmd/sql_cmd_test.cc | 116 +++++++++++++++++++++- src/nameserver/name_server_impl.cc | 72 +++++++++++--- src/nameserver/name_server_impl.h | 5 +- src/nameserver/system_table.h | 18 ++-- src/proto/name_server.proto | 19 ++++ src/sdk/sql_cluster_router.cc | 24 +++++ 17 files changed, 543 insertions(+), 34 deletions(-) diff --git a/hybridse/include/node/node_enum.h b/hybridse/include/node/node_enum.h index 7b189aa6aac..eea2bd9a953 100644 --- a/hybridse/include/node/node_enum.h +++ b/hybridse/include/node/node_enum.h @@ -98,6 +98,8 @@ enum SqlNodeType { kColumnSchema, kCreateUserStmt, 
kAlterUserStmt, + kGrantStmt, + kRevokeStmt, kCallStmt, kSqlNodeTypeLast, // debug type kVariadicUdfDef, @@ -347,6 +349,8 @@ enum PlanType { kPlanTypeShow, kPlanTypeCreateUser, kPlanTypeAlterUser, + kPlanTypeGrant, + kPlanTypeRevoke, kPlanTypeCallStmt, kUnknowPlan = -1, }; diff --git a/hybridse/include/node/plan_node.h b/hybridse/include/node/plan_node.h index ec82b6a586f..0e5683c4702 100644 --- a/hybridse/include/node/plan_node.h +++ b/hybridse/include/node/plan_node.h @@ -739,6 +739,67 @@ class CreateUserPlanNode : public LeafPlanNode { const std::shared_ptr options_; }; +class GrantPlanNode : public LeafPlanNode { + public: + explicit GrantPlanNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, + std::vector grantees, bool with_grant_option) + : LeafPlanNode(kPlanTypeGrant), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees), + with_grant_option_(with_grant_option) {} + ~GrantPlanNode() = default; + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + const bool WithGrantOption() const { return with_grant_option_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; + bool with_grant_option_; +}; + +class RevokePlanNode : public LeafPlanNode { + public: + explicit RevokePlanNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, + std::vector grantees) + : LeafPlanNode(kPlanTypeRevoke), + 
target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees) {} + ~RevokePlanNode() = default; + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; +}; + class AlterUserPlanNode : public LeafPlanNode { public: explicit AlterUserPlanNode(const std::string& name, bool if_exists, std::shared_ptr options) diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index 96ea7a94163..52542426c2a 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -2421,6 +2421,64 @@ class AlterUserNode : public SqlNode { const std::shared_ptr options_; }; +class GrantNode : public SqlNode { + public: + explicit GrantNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, std::vector grantees, + bool with_grant_option) + : SqlNode(kGrantStmt, 0, 0), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees), + with_grant_option_(with_grant_option) {} + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return 
is_all_privileges_; } + const bool WithGrantOption() const { return with_grant_option_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; + bool with_grant_option_; +}; + +class RevokeNode : public SqlNode { + public: + explicit RevokeNode(std::optional target_type, std::string database, std::string target, + std::vector privileges, bool is_all_privileges, std::vector grantees) + : SqlNode(kRevokeStmt, 0, 0), + target_type_(target_type), + database_(database), + target_(target), + privileges_(privileges), + is_all_privileges_(is_all_privileges), + grantees_(grantees) {} + const std::vector Privileges() const { return privileges_; } + const std::vector Grantees() const { return grantees_; } + const std::string Database() const { return database_; } + const std::string Target() const { return target_; } + const std::optional TargetType() const { return target_type_; } + const bool IsAllPrivileges() const { return is_all_privileges_; } + + private: + std::optional target_type_; + std::string database_; + std::string target_; + std::vector privileges_; + bool is_all_privileges_; + std::vector grantees_; +}; + class ExplainNode : public SqlNode { public: explicit ExplainNode(const QueryNode *query, node::ExplainType explain_type) diff --git a/hybridse/src/plan/planner.cc b/hybridse/src/plan/planner.cc index b2a57b4128c..3a3984c9b16 100644 --- a/hybridse/src/plan/planner.cc +++ b/hybridse/src/plan/planner.cc @@ -768,6 +768,22 @@ base::Status SimplePlanner::CreatePlanTree(const NodePointVector &parser_trees, plan_trees.push_back(create_user_plan_node); break; } + case ::hybridse::node::kGrantStmt: { + auto node = dynamic_cast(parser_tree); + auto grant_plan_node = node_manager_->MakeNode( + node->TargetType(), node->Database(), node->Target(), node->Privileges(), node->IsAllPrivileges(), + node->Grantees(), node->WithGrantOption()); + 
plan_trees.push_back(grant_plan_node); + break; + } + case ::hybridse::node::kRevokeStmt: { + auto node = dynamic_cast(parser_tree); + auto revoke_plan_node = node_manager_->MakeNode( + node->TargetType(), node->Database(), node->Target(), node->Privileges(), node->IsAllPrivileges(), + node->Grantees()); + plan_trees.push_back(revoke_plan_node); + break; + } case ::hybridse::node::kAlterUserStmt: { auto node = dynamic_cast(parser_tree); auto alter_user_plan_node = node_manager_->MakeNode(node->Name(), diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc index a8453e1221c..23e56924ae2 100644 --- a/hybridse/src/planv2/ast_node_converter.cc +++ b/hybridse/src/planv2/ast_node_converter.cc @@ -24,6 +24,7 @@ #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/types/span.h" +#include "ast_node_converter.h" #include "base/fe_status.h" #include "node/sql_node.h" #include "udf/udf.h" @@ -725,6 +726,20 @@ base::Status ConvertStatement(const zetasql::ASTStatement* statement, node::Node *output = create_user_node; break; } + case zetasql::AST_GRANT_STATEMENT: { + const zetasql::ASTGrantStatement* grant_stmt = statement->GetAsOrNull(); + node::GrantNode* grant_node = nullptr; + CHECK_STATUS(ConvertGrantStatement(grant_stmt, node_manager, &grant_node)) + *output = grant_node; + break; + } + case zetasql::AST_REVOKE_STATEMENT: { + const zetasql::ASTRevokeStatement* revoke_stmt = statement->GetAsOrNull(); + node::RevokeNode* revoke_node = nullptr; + CHECK_STATUS(ConvertRevokeStatement(revoke_stmt, node_manager, &revoke_node)) + *output = revoke_node; + break; + } case zetasql::AST_ALTER_USER_STATEMENT: { const zetasql::ASTAlterUserStatement* alter_user_stmt = statement->GetAsOrNull(); @@ -2133,6 +2148,81 @@ base::Status ConvertAlterUserStatement(const zetasql::ASTAlterUserStatement* roo return base::Status::OK(); } +base::Status ConvertGrantStatement(const zetasql::ASTGrantStatement* root, node::NodeManager* 
node_manager, + node::GrantNode** output) { + CHECK_TRUE(root != nullptr, common::kSqlAstError, "not an ASTGrantStatement"); + std::vector target_path; + CHECK_STATUS(AstPathExpressionToStringList(root->target_path(), target_path)); + std::optional target_type = std::nullopt; + if (root->target_type() != nullptr) { + target_type = root->target_type()->GetAsString(); + } + + std::vector privileges; + std::vector grantees; + for (auto privilege : root->privileges()->privileges()) { + if (privilege == nullptr) { + continue; + } + + auto privilege_action = privilege->privilege_action(); + if (privilege_action != nullptr) { + privileges.push_back(privilege_action->GetAsString()); + } + } + + for (auto grantee : root->grantee_list()->grantee_list()) { + if (grantee == nullptr) { + continue; + } + + std::string grantee_str; + CHECK_STATUS(AstStringLiteralToString(grantee, &grantee_str)); + grantees.push_back(grantee_str); + } + *output = node_manager->MakeNode(target_type, target_path.at(0), target_path.at(1), privileges, + root->privileges()->is_all_privileges(), grantees, + root->with_grant_option()); + return base::Status::OK(); +} + +base::Status ConvertRevokeStatement(const zetasql::ASTRevokeStatement* root, node::NodeManager* node_manager, + node::RevokeNode** output) { + CHECK_TRUE(root != nullptr, common::kSqlAstError, "not an ASTRevokeStatement"); + std::vector target_path; + CHECK_STATUS(AstPathExpressionToStringList(root->target_path(), target_path)); + std::optional target_type = std::nullopt; + if (root->target_type() != nullptr) { + target_type = root->target_type()->GetAsString(); + } + + std::vector privileges; + std::vector grantees; + for (auto privilege : root->privileges()->privileges()) { + if (privilege == nullptr) { + continue; + } + + auto privilege_action = privilege->privilege_action(); + if (privilege_action != nullptr) { + privileges.push_back(privilege_action->GetAsString()); + } + } + + for (auto grantee : 
root->grantee_list()->grantee_list()) { + if (grantee == nullptr) { + continue; + } + + std::string grantee_str; + CHECK_STATUS(AstStringLiteralToString(grantee, &grantee_str)); + grantees.push_back(grantee_str); + } + *output = node_manager->MakeNode(target_type, target_path.at(0), target_path.at(1), privileges, + root->privileges()->is_all_privileges(), grantees); + return base::Status::OK(); +} + base::Status ConvertCreateIndexStatement(const zetasql::ASTCreateIndexStatement* root, node::NodeManager* node_manager, node::CreateIndexNode** output) { CHECK_TRUE(nullptr != root, common::kSqlAstError, "not an ASTCreateIndexStatement") diff --git a/hybridse/src/planv2/ast_node_converter.h b/hybridse/src/planv2/ast_node_converter.h index 631569156d2..edc0fb60c50 100644 --- a/hybridse/src/planv2/ast_node_converter.h +++ b/hybridse/src/planv2/ast_node_converter.h @@ -72,6 +72,12 @@ base::Status ConvertCreateUserStatement(const zetasql::ASTCreateUserStatement* r base::Status ConvertAlterUserStatement(const zetasql::ASTAlterUserStatement* root, node::NodeManager* node_manager, node::AlterUserNode** output); +base::Status ConvertGrantStatement(const zetasql::ASTGrantStatement* root, node::NodeManager* node_manager, + node::GrantNode** output); + +base::Status ConvertRevokeStatement(const zetasql::ASTRevokeStatement* root, node::NodeManager* node_manager, + node::RevokeNode** output); + base::Status ConvertQueryNode(const zetasql::ASTQuery* root, node::NodeManager* node_manager, node::QueryNode** output); base::Status ConvertQueryExpr(const zetasql::ASTQueryExpression* query_expr, node::NodeManager* node_manager, diff --git a/src/auth/user_access_manager.cc b/src/auth/user_access_manager.cc index d668a7dc497..1f354998ef3 100644 --- a/src/auth/user_access_manager.cc +++ b/src/auth/user_access_manager.cc @@ -47,7 +47,7 @@ void UserAccessManager::StopSyncTask() { void UserAccessManager::SyncWithDB() { if (auto it_pair = 
user_table_iterator_factory_(::openmldb::nameserver::USER_INFO_NAME); it_pair) { - auto new_user_map = std::make_unique>(); + auto new_user_map = std::make_unique>(); auto it = it_pair->first.get(); it->SeekToFirst(); while (it->Valid()) { @@ -56,13 +56,18 @@ void UserAccessManager::SyncWithDB() { auto size = it->GetValue().size(); codec::RowView row_view(*it_pair->second.get(), buf, size); std::string host, user, password; + std::string privilege_level_str; row_view.GetStrValue(0, &host); row_view.GetStrValue(1, &user); row_view.GetStrValue(2, &password); + row_view.GetStrValue(5, &privilege_level_str); + openmldb::nameserver::PrivilegeLevel privilege_level; + ::openmldb::nameserver::PrivilegeLevel_Parse(privilege_level_str, &privilege_level); + UserRecord user_record = {password, privilege_level}; if (host == "%") { - new_user_map->emplace(user, password); + new_user_map->emplace(user, user_record); } else { - new_user_map->emplace(FormUserHost(user, host), password); + new_user_map->emplace(FormUserHost(user, host), user_record); } it->Next(); } @@ -70,12 +75,36 @@ void UserAccessManager::SyncWithDB() { } } +std::optional UserAccessManager::GetUserPassword(const std::string& host, const std::string& user) { + if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) { + return user_record.value().password; + } else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) { + return stored_password.value().password; + } else { + return std::nullopt; + } +} + bool UserAccessManager::IsAuthenticated(const std::string& host, const std::string& user, const std::string& password) { - if (auto stored_password = user_map_.Get(FormUserHost(user, host)); stored_password.has_value()) { - return stored_password.value() == password; + if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) { + return user_record.value().password == password; } else if (auto stored_password = 
user_map_.Get(user); stored_password.has_value()) { - return stored_password.value() == password; + return stored_password.value().password == password; } return false; } + +::openmldb::nameserver::PrivilegeLevel UserAccessManager::GetPrivilegeLevel(const std::string& user_at_host) { + std::size_t at_pos = user_at_host.find('@'); + if (at_pos != std::string::npos) { + std::string user = user_at_host.substr(0, at_pos); + std::string host = user_at_host.substr(at_pos + 1); + if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) { + return user_record.value().privilege_level; + } else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) { + return stored_password.value().privilege_level; + } + } + return ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE; +} } // namespace openmldb::auth diff --git a/src/auth/user_access_manager.h b/src/auth/user_access_manager.h index 996efc326c4..9de6890f93a 100644 --- a/src/auth/user_access_manager.h +++ b/src/auth/user_access_manager.h @@ -26,9 +26,15 @@ #include #include "catalog/distribute_iterator.h" +#include "proto/name_server.pb.h" #include "refreshable_map.h" namespace openmldb::auth { +struct UserRecord { + std::string password; + ::openmldb::nameserver::PrivilegeLevel privilege_level; +}; + class UserAccessManager { public: using IteratorFactory = std::function GetUserPassword(const std::string& host, const std::string& user); private: IteratorFactory user_table_iterator_factory_; - RefreshableMap user_map_; + RefreshableMap user_map_; std::thread sync_task_thread_; std::promise stop_promise_; void StartSyncTask(); diff --git a/src/base/status.h b/src/base/status.h index c7e5ec75198..a2da254e78e 100644 --- a/src/base/status.h +++ b/src/base/status.h @@ -186,7 +186,8 @@ enum ReturnCode { kRPCError = 1004, // brpc controller error // auth - kFlushPrivilegesFailed = 1100 // brpc controller error + kFlushPrivilegesFailed = 1100, // brpc controller error + 
kNotAuthorized = 1101 // brpc controller error }; struct Status { diff --git a/src/client/ns_client.cc b/src/client/ns_client.cc index cdeef07e521..9a4c6f4df6d 100644 --- a/src/client/ns_client.cc +++ b/src/client/ns_client.cc @@ -317,6 +317,35 @@ bool NsClient::PutUser(const std::string& host, const std::string& name, const s return false; } +bool NsClient::PutPrivilege(const std::optional target_type, const std::string database, + const std::string target, const std::vector privileges, + const bool is_all_privileges, const std::vector grantees, + const ::openmldb::nameserver::PrivilegeLevel privilege_level) { + ::openmldb::nameserver::PutPrivilegeRequest request; + if (target_type.has_value()) { + request.set_target_type(target_type.value()); + } + request.set_database(database); + request.set_target(target); + for (const auto& privilege : privileges) { + request.add_privilege(privilege); + } + request.set_is_all_privileges(is_all_privileges); + for (const auto& grantee : grantees) { + request.add_grantee(grantee); + } + + request.set_privilege_level(privilege_level); + + ::openmldb::nameserver::GeneralResponse response; + bool ok = client_.SendRequest(&::openmldb::nameserver::NameServer_Stub::PutPrivilege, &request, &response, + FLAGS_request_timeout_ms, 1); + if (ok && response.code() == 0) { + return true; + } + return false; +} + bool NsClient::DeleteUser(const std::string& host, const std::string& name) { ::openmldb::nameserver::DeleteUserRequest request; request.set_host(host); diff --git a/src/client/ns_client.h b/src/client/ns_client.h index 73a52854765..1ddd50963bf 100644 --- a/src/client/ns_client.h +++ b/src/client/ns_client.h @@ -112,6 +112,11 @@ class NsClient : public Client { bool PutUser(const std::string& host, const std::string& name, const std::string& password); // NOLINT + bool PutPrivilege(const std::optional target_type, const std::string database, + const std::string target, const std::vector privileges, const bool is_all_privileges, + 
const std::vector grantees, + const ::openmldb::nameserver::PrivilegeLevel privilege_level); // NOLINT + bool DeleteUser(const std::string& host, const std::string& name); // NOLINT bool DropTable(const std::string& db, const std::string& name, diff --git a/src/cmd/sql_cmd_test.cc b/src/cmd/sql_cmd_test.cc index fe8faa21504..79225eb52dd 100644 --- a/src/cmd/sql_cmd_test.cc +++ b/src/cmd/sql_cmd_test.cc @@ -243,7 +243,6 @@ TEST_P(DBSDKTest, TestUser) { ASSERT_FALSE(status.IsOK()); sr->ExecuteSQL(absl::StrCat("CREATE USER IF NOT EXISTS user1"), &status); ASSERT_TRUE(status.IsOK()); - ASSERT_TRUE(true); auto opt = sr->GetRouterOptions(); if (cs->IsClusterMode()) { auto real_opt = std::dynamic_pointer_cast(opt); @@ -280,6 +279,121 @@ TEST_P(DBSDKTest, TestUser) { ASSERT_TRUE(status.IsOK()); } +TEST_P(DBSDKTest, TestGrantCreateUser) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + hybridse::sdk::Status status; + sr->ExecuteSQL(absl::StrCat("CREATE USER user1 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + auto opt = sr->GetRouterOptions(); + if (cs->IsClusterMode()) { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::SQLRouterOptions opt1; + opt1.zk_cluster = real_opt->zk_cluster; + opt1.zk_path = real_opt->zk_path; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewClusterSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + 
sr->ExecuteSQL(absl::StrCat("REVOKE CREATE USER ON *.* FROM 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user3 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_FALSE(status.IsOK()); + } else { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::StandaloneOptions opt1; + opt1.host = real_opt->host; + opt1.port = real_opt->port; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewStandaloneSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("REVOKE CREATE USER ON *.* FROM 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("CREATE USER user3 OPTIONS(password='123456')"), &status); + ASSERT_FALSE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("DROP USER user2"), &status); + ASSERT_FALSE(status.IsOK()); + } + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user1"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user2"), &status); + ASSERT_TRUE(status.IsOK()); +} + +TEST_P(DBSDKTest, TestGrantCreateUserGrantOption) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + hybridse::sdk::Status status; + sr->ExecuteSQL(absl::StrCat("CREATE USER user1 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + 
sr->ExecuteSQL(absl::StrCat("CREATE USER user2 OPTIONS(password='123456')"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%'"), &status); + ASSERT_TRUE(status.IsOK()); + + auto opt = sr->GetRouterOptions(); + if (cs->IsClusterMode()) { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::SQLRouterOptions opt1; + opt1.zk_cluster = real_opt->zk_cluster; + opt1.zk_path = real_opt->zk_path; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewClusterSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%' WITH GRANT OPTION"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_TRUE(status.IsOK()); + } else { + auto real_opt = std::dynamic_pointer_cast(opt); + sdk::StandaloneOptions opt1; + opt1.host = real_opt->host; + opt1.port = real_opt->port; + opt1.user = "user1"; + opt1.password = "123456"; + auto router = NewStandaloneSQLRouter(opt1); + ASSERT_TRUE(router != nullptr); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_FALSE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user1@%' WITH GRANT OPTION"), &status); + ASSERT_TRUE(status.IsOK()); + router->ExecuteSQL(absl::StrCat("GRANT CREATE USER ON *.* TO 'user2@%'"), &status); + ASSERT_TRUE(status.IsOK()); + } + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user1"), &status); + ASSERT_TRUE(status.IsOK()); + sr->ExecuteSQL(absl::StrCat("DROP USER IF EXISTS user2"), &status); + ASSERT_TRUE(status.IsOK()); +} + TEST_P(DBSDKTest, CreateDatabase) { auto cli = GetParam(); cs = cli->cs; diff --git a/src/nameserver/name_server_impl.cc b/src/nameserver/name_server_impl.cc index 9c565272fb3..5e65a7d2d94 100644 
--- a/src/nameserver/name_server_impl.cc +++ b/src/nameserver/name_server_impl.cc @@ -1380,7 +1380,8 @@ void NameServerImpl::ShowTablet(RpcController* controller, const ShowTabletReque } base::Status NameServerImpl::PutUserRecord(const std::string& host, const std::string& user, - const std::string& password) { + const std::string& password, + const ::openmldb::nameserver::PrivilegeLevel privilege_level) { std::shared_ptr table_info; if (!GetTableInfo(USER_INFO_NAME, INTERNAL_DB, &table_info)) { return {ReturnCode::kTableIsNotExist, "user table does not exist"}; @@ -1391,12 +1392,8 @@ base::Status NameServerImpl::PutUserRecord(const std::string& host, const std::s row_values.push_back(user); row_values.push_back(password); row_values.push_back("0"); // password_last_changed - row_values.push_back("0"); // password_expired_time - row_values.push_back("0"); // create_time - row_values.push_back("0"); // update_time - row_values.push_back("1"); // account_type - row_values.push_back("0"); // privileges - row_values.push_back("null"); // extra_info + row_values.push_back("0"); // password_expired + row_values.push_back(PrivilegeLevel_Name(privilege_level)); // Create_user_priv std::string encoded_row; codec::RowCodec::EncodeRow(row_values, table_info->column_desc(), 1, encoded_row); @@ -1431,7 +1428,6 @@ base::Status NameServerImpl::DeleteUserRecord(const std::string& host, const std for (int meta_idx = 0; meta_idx < table_partition.partition_meta_size(); meta_idx++) { if (table_partition.partition_meta(meta_idx).is_leader() && table_partition.partition_meta(meta_idx).is_alive()) { - uint64_t cur_ts = ::baidu::common::timer::get_micros() / 1000; std::string endpoint = table_partition.partition_meta(meta_idx).endpoint(); auto table_ptr = GetTablet(endpoint); if (!table_ptr->client_->Delete(tid, 0, host + "|" + user, "index", msg)) { @@ -5640,7 +5636,8 @@ void NameServerImpl::OnLocked() { CreateDatabaseOrExit(INTERNAL_DB); if 
(db_table_info_[INTERNAL_DB].count(USER_INFO_NAME) == 0) { CreateSystemTableOrExit(SystemTableType::kUser); - PutUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"); + PutUserRecord("%", "root", "1e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE_WITH_GRANT_OPTION); } if (IsClusterMode()) { if (tablets_.size() < FLAGS_system_table_replica_num) { @@ -9663,15 +9660,64 @@ NameServerImpl::GetSystemTableIterator() { void NameServerImpl::PutUser(RpcController* controller, const PutUserRequest* request, GeneralResponse* response, Closure* done) { brpc::ClosureGuard done_guard(done); - auto status = PutUserRecord(request->host(), request->name(), request->password()); - base::SetResponseStatus(status, response); + brpc::Controller* brpc_controller = static_cast(controller); + + if (brpc_controller->auth_context()->is_service() || + user_access_manager_.GetPrivilegeLevel(brpc_controller->auth_context()->user()) > + ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE) { + auto status = PutUserRecord(request->host(), request->name(), request->password(), + ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE); + base::SetResponseStatus(status, response); + } else { + base::SetResponseStatus(base::ReturnCode::kNotAuthorized, "not authorized to create user", response); + } +} + +void NameServerImpl::PutPrivilege(RpcController* controller, const PutPrivilegeRequest* request, + GeneralResponse* response, Closure* done) { + brpc::ClosureGuard done_guard(done); + + for (int i = 0; i < request->privilege_size(); ++i) { + auto privilege = request->privilege(i); + if (privilege == "CREATE USER") { + brpc::Controller* brpc_controller = static_cast(controller); + if (brpc_controller->auth_context()->is_service() || + user_access_manager_.GetPrivilegeLevel(brpc_controller->auth_context()->user()) >= + ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE_WITH_GRANT_OPTION) { + for 
(int i = 0; i < request->grantee_size(); ++i) { + auto grantee = request->grantee(i); + std::size_t at_pos = grantee.find('@'); + if (at_pos != std::string::npos) { + std::string user = grantee.substr(0, at_pos); + std::string host = grantee.substr(at_pos + 1); + auto password = user_access_manager_.GetUserPassword(host, user); + if (password.has_value()) { + auto status = PutUserRecord(host, user, password.value(), request->privilege_level()); + base::SetResponseStatus(status, response); + } + } + } + } else { + base::SetResponseStatus(base::ReturnCode::kNotAuthorized, + "not authorized to grant create user privilege", response); + } + } + } } void NameServerImpl::DeleteUser(RpcController* controller, const DeleteUserRequest* request, GeneralResponse* response, Closure* done) { brpc::ClosureGuard done_guard(done); - auto status = DeleteUserRecord(request->host(), request->name()); - base::SetResponseStatus(status, response); + brpc::Controller* brpc_controller = static_cast(controller); + + if (brpc_controller->auth_context()->is_service() || + user_access_manager_.GetPrivilegeLevel(brpc_controller->auth_context()->user()) > + ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE) { + auto status = DeleteUserRecord(request->host(), request->name()); + base::SetResponseStatus(status, response); + } else { + base::SetResponseStatus(base::ReturnCode::kNotAuthorized, "not authorized to create user", response); + } } bool NameServerImpl::IsAuthenticated(const std::string& host, const std::string& username, diff --git a/src/nameserver/name_server_impl.h b/src/nameserver/name_server_impl.h index dadc335c7a3..53f44cde278 100644 --- a/src/nameserver/name_server_impl.h +++ b/src/nameserver/name_server_impl.h @@ -360,6 +360,8 @@ class NameServerImpl : public NameServer { Closure* done); void PutUser(RpcController* controller, const PutUserRequest* request, GeneralResponse* response, Closure* done); + void PutPrivilege(RpcController* controller, const PutPrivilegeRequest* 
request, GeneralResponse* response, + Closure* done); void DeleteUser(RpcController* controller, const DeleteUserRequest* request, GeneralResponse* response, Closure* done); bool IsAuthenticated(const std::string& host, const std::string& username, const std::string& password); @@ -373,7 +375,8 @@ class NameServerImpl : public NameServer { bool GetTableInfo(const std::string& table_name, const std::string& db_name, std::shared_ptr* table_info); - base::Status PutUserRecord(const std::string& host, const std::string& user, const std::string& password); + base::Status PutUserRecord(const std::string& host, const std::string& user, const std::string& password, + const ::openmldb::nameserver::PrivilegeLevel privilege_level); base::Status DeleteUserRecord(const std::string& host, const std::string& user); base::Status FlushPrivileges(); diff --git a/src/nameserver/system_table.h b/src/nameserver/system_table.h index cda34e1798e..03c8bc2364e 100644 --- a/src/nameserver/system_table.h +++ b/src/nameserver/system_table.h @@ -163,20 +163,16 @@ class SystemTable { break; } case SystemTableType::kUser: { - SetColumnDesc("host", type::DataType::kString, table_info->add_column_desc()); - SetColumnDesc("user", type::DataType::kString, table_info->add_column_desc()); - SetColumnDesc("password", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("Host", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("User", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("authentication_string", type::DataType::kString, table_info->add_column_desc()); SetColumnDesc("password_last_changed", type::DataType::kTimestamp, table_info->add_column_desc()); - SetColumnDesc("password_expired_time", type::DataType::kBigInt, table_info->add_column_desc()); - SetColumnDesc("create_time", type::DataType::kTimestamp, table_info->add_column_desc()); - SetColumnDesc("update_time", type::DataType::kTimestamp, table_info->add_column_desc()); - 
SetColumnDesc("account_type", type::DataType::kInt, table_info->add_column_desc()); - SetColumnDesc("privileges", type::DataType::kString, table_info->add_column_desc()); - SetColumnDesc("extra_info", type::DataType::kString, table_info->add_column_desc()); + SetColumnDesc("password_expired", type::DataType::kTimestamp, table_info->add_column_desc()); + SetColumnDesc("Create_user_priv", type::DataType::kString, table_info->add_column_desc()); auto index = table_info->add_column_key(); index->set_index_name("index"); - index->add_col_name("host"); - index->add_col_name("user"); + index->add_col_name("Host"); + index->add_col_name("User"); auto ttl = index->mutable_ttl(); ttl->set_ttl_type(::openmldb::type::kLatestTime); ttl->set_lat_ttl(1); diff --git a/src/proto/name_server.proto b/src/proto/name_server.proto index f7c8fd5c830..14cd00d6ddd 100755 --- a/src/proto/name_server.proto +++ b/src/proto/name_server.proto @@ -544,6 +544,22 @@ message DeleteUserRequest { required string name = 2; } +enum PrivilegeLevel { + NO_PRIVILEGE = 0; + PRIVILEGE = 1; + PRIVILEGE_WITH_GRANT_OPTION = 2; +} + +message PutPrivilegeRequest { + repeated string grantee = 1; + repeated string privilege = 2; + optional string target_type = 3; + required string database = 4; + required string target = 5; + required bool is_all_privileges = 6; + required PrivilegeLevel privilege_level = 7; +} + message DeploySQLRequest { optional openmldb.api.ProcedureInfo sp_info = 3; repeated TableIndex index = 4; @@ -617,4 +633,7 @@ service NameServer { // user related interfaces rpc PutUser(PutUserRequest) returns (GeneralResponse); rpc DeleteUser(DeleteUserRequest) returns (GeneralResponse); + + // authz related interfaces + rpc PutPrivilege(PutPrivilegeRequest) returns (GeneralResponse); } diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index dbdd7dede9d..3d09156fdcc 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -2786,6 +2786,30 @@ 
std::shared_ptr SQLClusterRouter::ExecuteSQL( } return {}; } + case hybridse::node::kPlanTypeGrant: { + auto grant_node = dynamic_cast(node); + auto ns = cluster_sdk_->GetNsClient(); + auto ok = ns->PutPrivilege(grant_node->TargetType(), grant_node->Database(), grant_node->Target(), + grant_node->Privileges(), grant_node->IsAllPrivileges(), grant_node->Grantees(), + grant_node->WithGrantOption() + ? ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE_WITH_GRANT_OPTION + : ::openmldb::nameserver::PrivilegeLevel::PRIVILEGE); + if (!ok) { + *status = {StatusCode::kCmdError, "Grant API call failed"}; + } + return {}; + } + case hybridse::node::kPlanTypeRevoke: { + auto revoke_node = dynamic_cast(node); + auto ns = cluster_sdk_->GetNsClient(); + auto ok = ns->PutPrivilege(revoke_node->TargetType(), revoke_node->Database(), revoke_node->Target(), + revoke_node->Privileges(), revoke_node->IsAllPrivileges(), + revoke_node->Grantees(), ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE); + if (!ok) { + *status = {StatusCode::kCmdError, "Revoke API call failed"}; + } + return {}; + } case hybridse::node::kPlanTypeAlterUser: { auto alter_node = dynamic_cast(node); UserInfo user_info; From 289b746bc302388605d31c2c8d4e256a54f0ca46 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:26:01 +0800 Subject: [PATCH 20/23] build(deps-dev): bump requests from 2.31.0 to 2.32.2 in /docs (#3951) Bumps [requests](https://github.com/psf/requests) from 2.31.0 to 2.32.2. - [Release notes](https://github.com/psf/requests/releases) - [Changelog](https://github.com/psf/requests/blob/main/HISTORY.md) - [Commits](https://github.com/psf/requests/compare/v2.31.0...v2.32.2) --- updated-dependencies: - dependency-name: requests dependency-type: indirect ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/poetry.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/poetry.lock b/docs/poetry.lock index 39577275304..2e0e52f5a21 100644 --- a/docs/poetry.lock +++ b/docs/poetry.lock @@ -425,13 +425,13 @@ files = [ [[package]] name = "requests" -version = "2.31.0" +version = "2.32.2" description = "Python HTTP for Humans." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, + {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, ] [package.dependencies] From ca7f7e2085c3634e2087668d88b6fa320a75190a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:26:12 +0800 Subject: [PATCH 21/23] build(deps-dev): bump org.apache.derby:derby (#3949) Bumps org.apache.derby:derby from 10.14.2.0 to 10.17.1.0. --- updated-dependencies: - dependency-name: org.apache.derby:derby dependency-type: direct:development ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- extensions/kafka-connect-jdbc/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/kafka-connect-jdbc/pom.xml b/extensions/kafka-connect-jdbc/pom.xml index a0c3cf0512d..2f35d1bd1e1 100644 --- a/extensions/kafka-connect-jdbc/pom.xml +++ b/extensions/kafka-connect-jdbc/pom.xml @@ -53,7 +53,7 @@ - 10.14.2.0 + 10.17.1.0 2.7 0.11.1 3.41.2.2 From 25bd745badeefd54d3f8b0c75c8c5d69ceb58a42 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:26:18 +0800 Subject: [PATCH 22/23] build(deps): bump org.postgresql:postgresql (#3950) Bumps [org.postgresql:postgresql](https://github.com/pgjdbc/pgjdbc) from 42.3.3 to 42.3.9. - [Release notes](https://github.com/pgjdbc/pgjdbc/releases) - [Changelog](https://github.com/pgjdbc/pgjdbc/blob/master/CHANGELOG.md) - [Commits](https://github.com/pgjdbc/pgjdbc/compare/REL42.3.3...REL42.3.9) --- updated-dependencies: - dependency-name: org.postgresql:postgresql dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- extensions/kafka-connect-jdbc/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/kafka-connect-jdbc/pom.xml b/extensions/kafka-connect-jdbc/pom.xml index 2f35d1bd1e1..5d911a342c1 100644 --- a/extensions/kafka-connect-jdbc/pom.xml +++ b/extensions/kafka-connect-jdbc/pom.xml @@ -59,7 +59,7 @@ 3.41.2.2 19.7.0.0 8.4.1.jre8 - 42.3.3 + 42.3.9 1.3.1 0.8.5 Confluent Community License From 1c1e2134316c2cdf51be1d55beb9c41851cc8175 Mon Sep 17 00:00:00 2001 From: HuangWei Date: Mon, 1 Jul 2024 11:45:43 +0800 Subject: [PATCH 23/23] feat: iot table (#3944) * feat: iot table * fix * fix * fix delete key entry * fix comment * ut * ut test * fix ut * sleep more for truncate * sleep 16 * tool pytest fix and swig fix * fix * clean * move to base * fix * fix coverage ut * fix --------- Co-authored-by: Huang Wei --- .../ddl/CREATE_INDEX_STATEMENT.md | 6 + .../ddl/CREATE_TABLE_STATEMENT.md | 23 +- hybridse/include/node/node_manager.h | 4 +- hybridse/include/node/sql_node.h | 13 +- hybridse/src/node/node_manager.cc | 9 +- hybridse/src/node/plan_node_test.cc | 2 +- hybridse/src/node/sql_node.cc | 2 + hybridse/src/node/sql_node_test.cc | 2 +- hybridse/src/plan/planner.cc | 3 +- hybridse/src/planv2/ast_node_converter.cc | 37 +- hybridse/src/sdk/codec_sdk.cc | 2 +- src/base/index_util.cc | 121 ++++ src/base/index_util.h | 45 ++ src/base/status.h | 2 +- src/base/status_util.h | 7 + src/catalog/distribute_iterator.cc | 2 +- src/client/tablet_client.cc | 147 ++-- src/client/tablet_client.h | 16 +- src/cmd/display.h | 3 +- src/cmd/sql_cmd_test.cc | 18 +- src/codec/field_codec.h | 4 +- src/flags.cc | 4 + src/nameserver/name_server_impl.cc | 20 +- src/proto/common.proto | 7 + src/proto/tablet.proto | 1 + src/sdk/node_adapter.cc | 7 + src/sdk/option.h | 14 +- src/sdk/sql_cluster_router.cc | 364 ++++++++-- src/storage/index_organized_table.cc | 634 
++++++++++++++++++ src/storage/index_organized_table.h | 68 ++ src/storage/iot_segment.cc | 412 ++++++++++++ src/storage/iot_segment.h | 298 ++++++++ src/storage/iot_segment_test.cc | 517 ++++++++++++++ src/storage/key_entry.cc | 2 +- src/storage/mem_table.cc | 72 +- src/storage/mem_table.h | 25 +- src/storage/mem_table_iterator.cc | 7 +- src/storage/mem_table_iterator.h | 24 +- src/storage/node_cache.cc | 5 +- src/storage/record.h | 17 +- src/storage/schema.cc | 25 +- src/storage/schema.h | 35 +- src/storage/segment.cc | 81 ++- src/storage/segment.h | 25 +- src/storage/table.cc | 4 +- src/storage/table_iterator_test.cc | 12 +- src/tablet/tablet_impl.cc | 214 ++++-- src/tablet/tablet_impl_test.cc | 14 +- steps/test_python.sh | 4 +- tools/tool.py | 11 +- 50 files changed, 3082 insertions(+), 309 deletions(-) create mode 100644 src/base/index_util.cc create mode 100644 src/base/index_util.h create mode 100644 src/storage/index_organized_table.cc create mode 100644 src/storage/index_organized_table.h create mode 100644 src/storage/iot_segment.cc create mode 100644 src/storage/iot_segment.h create mode 100644 src/storage/iot_segment_test.cc diff --git a/docs/zh/openmldb_sql/ddl/CREATE_INDEX_STATEMENT.md b/docs/zh/openmldb_sql/ddl/CREATE_INDEX_STATEMENT.md index abfa201ab29..0ed5661e993 100644 --- a/docs/zh/openmldb_sql/ddl/CREATE_INDEX_STATEMENT.md +++ b/docs/zh/openmldb_sql/ddl/CREATE_INDEX_STATEMENT.md @@ -55,6 +55,12 @@ CREATE INDEX index3 ON t5 (col3) OPTIONS (ts=ts1, ttl_type=absolute, ttl=30d); ``` 关于`TTL`和`TTL_TYPE`的更多信息参考[这里](./CREATE_TABLE_STATEMENT.md) +IOT表创建不同类型的索引,不指定type创建Covering索引,指定type为secondary,创建Secondary索引: +```SQL +CREATE INDEX index_s ON t5 (col3) OPTIONS (ts=ts1, ttl_type=absolute, ttl=30d, type=secondary); +``` +同keys和ts列的索引被视为同一个索引,不要尝试建立不同type的同一索引。 + ## 相关SQL [DROP INDEX](./DROP_INDEX_STATEMENT.md) diff --git a/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md b/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md index 
750b198d897..895cd5f43f6 100644 --- a/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md +++ b/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md @@ -223,7 +223,7 @@ IndexOption ::= | 配置项 | 描述 | expr | 用法示例 | |------------|---------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------| -| `KEY` | 索引列(必选)。OpenMLDB支持单列索引,也支持联合索引。当`KEY`后只有一列时,仅在该列上建立索引。当`KEY`后有多列时,建立这几列的联合索引:将多列按顺序拼接成一个字符串作为索引。 | 支持单列索引:`ColumnName`
或联合索引:
`(ColumnName (, ColumnName)* ) ` | 单列索引:`INDEX(KEY=col1)`
联合索引:`INDEX(KEY=(col1, col2))` | +| `KEY/CKEY/SKEY` | 索引列(必选)。OpenMLDB支持单列索引,也支持联合索引。当`KEY`后只有一列时,仅在该列上建立索引。当`KEY`后有多列时,建立这几列的联合索引:将多列按顺序拼接成一个字符串作为索引。多KEY使用见[Index-Orgnized Table(IOT)](#index-orgnized-tableiot)。 | 支持单列索引:`ColumnName`
或联合索引:
`(ColumnName (, ColumnName)* ) ` | 单列索引:`INDEX(KEY=col1)`
联合索引:`INDEX(KEY=(col1, col2))` | | `TS` | 索引时间列(可选)。同一个索引上的数据将按照时间索引列排序。当不显式配置`TS`时,使用数据插入的时间戳作为索引时间。时间列的类型只能为BigInt或者Timestamp | `ColumnName` | `INDEX(KEY=col1, TS=std_time)`。索引列为col1,col1相同的数据行按std_time排序。 | | `TTL_TYPE` | 淘汰规则(可选)。包括四种类型,当不显式配置`TTL_TYPE`时,默认使用`ABSOLUTE`过期配置。 | 支持的expr如下:`ABSOLUTE`
`LATEST`
`ABSORLAT`
`ABSANDLAT`。 | 具体用法可以参考下文“TTL和TTL_TYPE的配置细则” | | `TTL` | 最大存活时间/条数(可选)。依赖于`TTL_TYPE`,不同的`TTL_TYPE`有不同的`TTL` 配置方式。当不显式配置`TTL`时,`TTL=0`,表示不设置淘汰规则,OpenMLDB将不会淘汰记录。 | 支持数值:`int_literal`
或数值带时间单位(`S,M,H,D`):`interval_literal`
或元组形式:`( interval_literal , int_literal )` |具体用法可以参考下文“TTL和TTL_TYPE的配置细则” | @@ -240,6 +240,27 @@ IndexOption ::= ```{note} 最大过期时间和最大存活条数的限制,是出于性能考虑。如果你一定要配置更大的TTL值,可先创建表时临时使用合规的TTL值,然后使用nameserver的UpdateTTL接口来调整到所需的值(可无视max限制),生效需要经过一个gc时间;或者,调整nameserver配置`absolute_ttl_max`和`latest_ttl_max`,重启生效后再创建表。 ``` +#### Index-Orgnized Table(IOT) + +索引使用KEY设置时创建Covering索引,在OpenMLDB中Covering索引存储完整的数据行,也因此占用内存较多。如果希望内存占用更低,同时允许性能损失,可以使用IOT表。IOT表中可以建三种类型的索引: +- `CKEY`:Clustered索引,存完整数据行。配置的CKEY+TS用于唯一标识一行数据,INSERT重复主键时将更新数据(会触发所有索引上的删除旧数据,再INSERT新数据,性能会有损失)。也可只使用CKEY,不配置TS,CKEY唯一标识一行数据。查询到此索引的性能无损失。 +- `SKEY`:Secondary索引,存主键。不配置TS时,同SKEY下按插入时间排序。查询时先在Secondary索引中找到对应主键值,再根据主键查数据,查询性能有损失。 +- `KEY`:Covering索引,存完整数据行。不配置TS时,同KEY下按插入时间排序。查询到此索引的性能无损失。 + +创建IOT表,第一个索引必须是唯一一个Clustered索引,其他索引可选。暂不支持调整Clustered索引的顺序。 + +```sql +CREATE TABLE iot (c1 int64, c2 int64, c3 int64, INDEX(ckey=c1, ts=c2)); -- 一个Clustered索引 +CREATE TABLE iot (c1 int64, c2 int64, c3 int64, INDEX(ckey=c1), INDEX(skey=c2)); -- 一个Clustered索引和一个Secondary索引 +CREATE TABLE iot (c1 int64, c2 int64, c3 int64, INDEX(ckey=c1), INDEX(skey=c2), INDEX(key=c3)); -- 一个Clustered索引、一个Secondary索引和一个Covering索引 +``` + +IOT各个索引的TTL与普通表的不同点是,IOT Clustered索引的ttl淘汰,将触发其他索引的删除操作,而Secondary索引和Covering索引的ttl淘汰,只会删除自身索引中的数据,不会触发其他索引的删除操作。通常来讲,除非有必要让Secondary和Covering索引更加节约内存,可以只设置Clustered索引的ttl,不设置Secondary和Covering索引的ttl。 + +##### 注意事项 + +- IOT表不可以并发写入相同主键的多条数据,可能出现冲突,至少一条数据会写入失败。IOT表中已存在的相同主键的数据不需要额外处理,将会被覆盖。为了不用修复导入,请在导入前做好数据清洗,对导入数据中相同主键的数据进行去重。(覆盖会出触发所有索引中的删除,单线程写入效率也非常低,所以并不推荐单线程导入。) +- #### Example **示例1:创建一张带单列索引的表** diff --git a/hybridse/include/node/node_manager.h b/hybridse/include/node/node_manager.h index 9fc217d6f82..bc29b484f16 100644 --- a/hybridse/include/node/node_manager.h +++ b/hybridse/include/node/node_manager.h @@ -173,8 +173,8 @@ class NodeManager { SqlNode *MakeColumnIndexNode(SqlNodeList *keys, SqlNode *ts, SqlNode *ttl, SqlNode *version); SqlNode *MakeColumnIndexNode(SqlNodeList *index_item_list); - SqlNode 
*MakeIndexKeyNode(const std::string &key); - SqlNode *MakeIndexKeyNode(const std::vector &keys); + SqlNode *MakeIndexKeyNode(const std::string &key, const std::string &type); + SqlNode *MakeIndexKeyNode(const std::vector &keys, const std::string &type); SqlNode *MakeIndexTsNode(const std::string &ts); SqlNode *MakeIndexTTLNode(ExprListNode *ttl_expr); SqlNode *MakeIndexTTLTypeNode(const std::string &ttl_type); diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index 52542426c2a..eec63833617 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -2084,14 +2084,19 @@ class CreateStmt : public SqlNode { class IndexKeyNode : public SqlNode { public: IndexKeyNode() : SqlNode(kIndexKey, 0, 0) {} - explicit IndexKeyNode(const std::string &key) : SqlNode(kIndexKey, 0, 0), key_({key}) {} - explicit IndexKeyNode(const std::vector &keys) : SqlNode(kIndexKey, 0, 0), key_(keys) {} + explicit IndexKeyNode(const std::string &key, const std::string &type) + : SqlNode(kIndexKey, 0, 0), key_({key}), index_type_(type) {} + explicit IndexKeyNode(const std::vector &keys, const std::string &type) + : SqlNode(kIndexKey, 0, 0), key_(keys), index_type_(type) {} ~IndexKeyNode() {} void AddKey(const std::string &key) { key_.push_back(key); } + void SetIndexType(const std::string &type) { index_type_ = type; } std::vector &GetKey() { return key_; } + std::string &GetIndexType() { return index_type_; } private: std::vector key_; + std::string index_type_ = "key"; }; class IndexVersionNode : public SqlNode { public: @@ -2145,6 +2150,7 @@ class ColumnIndexNode : public SqlNode { public: ColumnIndexNode() : SqlNode(kColumnIndex, 0, 0), + index_type_("key"), ts_(""), version_(""), version_count_(0), @@ -2155,6 +2161,8 @@ class ColumnIndexNode : public SqlNode { std::vector &GetKey() { return key_; } void SetKey(const std::vector &key) { key_ = key; } + void SetIndexType(const std::string &type) { index_type_ = type; } + std::string 
&GetIndexType() { return index_type_; } std::string GetTs() const { return ts_; } @@ -2183,6 +2191,7 @@ class ColumnIndexNode : public SqlNode { private: std::vector key_; + std::string index_type_; std::string ts_; std::string version_; int version_count_; diff --git a/hybridse/src/node/node_manager.cc b/hybridse/src/node/node_manager.cc index ffa1fe2092f..91936235000 100644 --- a/hybridse/src/node/node_manager.cc +++ b/hybridse/src/node/node_manager.cc @@ -451,6 +451,7 @@ SqlNode *NodeManager::MakeColumnIndexNode(SqlNodeList *index_item_list) { switch (node_ptr->GetType()) { case kIndexKey: index_ptr->SetKey(dynamic_cast(node_ptr)->GetKey()); + index_ptr->SetIndexType(dynamic_cast(node_ptr)->GetIndexType()); break; case kIndexTs: index_ptr->SetTs(dynamic_cast(node_ptr)->GetColumnName()); @@ -649,12 +650,12 @@ FnParaNode *NodeManager::MakeFnParaNode(const std::string &name, const TypeNode ::hybridse::node::FnParaNode *para_node = new ::hybridse::node::FnParaNode(expr_id); return RegisterNode(para_node); } -SqlNode *NodeManager::MakeIndexKeyNode(const std::string &key) { - SqlNode *node_ptr = new IndexKeyNode(key); +SqlNode *NodeManager::MakeIndexKeyNode(const std::string &key, const std::string &type) { + SqlNode *node_ptr = new IndexKeyNode(key, type); return RegisterNode(node_ptr); } -SqlNode *NodeManager::MakeIndexKeyNode(const std::vector &keys) { - SqlNode *node_ptr = new IndexKeyNode(keys); +SqlNode *NodeManager::MakeIndexKeyNode(const std::vector &keys, const std::string &type) { + SqlNode *node_ptr = new IndexKeyNode(keys, type); return RegisterNode(node_ptr); } SqlNode *NodeManager::MakeIndexTsNode(const std::string &ts) { diff --git a/hybridse/src/node/plan_node_test.cc b/hybridse/src/node/plan_node_test.cc index aac111f8bf3..68eb0349a71 100644 --- a/hybridse/src/node/plan_node_test.cc +++ b/hybridse/src/node/plan_node_test.cc @@ -228,7 +228,7 @@ TEST_F(PlanNodeTest, MultiPlanNodeTest) { TEST_F(PlanNodeTest, ExtractColumnsAndIndexsTest) { SqlNodeList 
*index_items = manager_->MakeNodeList(); - index_items->PushBack(manager_->MakeIndexKeyNode("col4")); + index_items->PushBack(manager_->MakeIndexKeyNode("col4", "key")); index_items->PushBack(manager_->MakeIndexTsNode("col5")); ColumnIndexNode *index_node = dynamic_cast(manager_->MakeColumnIndexNode(index_items)); index_node->SetName("index1"); diff --git a/hybridse/src/node/sql_node.cc b/hybridse/src/node/sql_node.cc index 5055b7dabb2..05dc87e34d6 100644 --- a/hybridse/src/node/sql_node.cc +++ b/hybridse/src/node/sql_node.cc @@ -1188,6 +1188,8 @@ static absl::flat_hash_map CreateSqlNodeTypeToNa {kCreateFunctionStmt, "kCreateFunctionStmt"}, {kCreateUserStmt, "kCreateUserStmt"}, {kAlterUserStmt, "kAlterUserStmt"}, + {kRevokeStmt, "kRevokeStmt"}, + {kGrantStmt, "kGrantStmt"}, {kDynamicUdfFnDef, "kDynamicUdfFnDef"}, {kDynamicUdafFnDef, "kDynamicUdafFnDef"}, {kWithClauseEntry, "kWithClauseEntry"}, diff --git a/hybridse/src/node/sql_node_test.cc b/hybridse/src/node/sql_node_test.cc index 67bb861a812..c67a21b31d7 100644 --- a/hybridse/src/node/sql_node_test.cc +++ b/hybridse/src/node/sql_node_test.cc @@ -666,7 +666,7 @@ TEST_F(SqlNodeTest, IndexVersionNodeTest) { TEST_F(SqlNodeTest, CreateIndexNodeTest) { SqlNodeList *index_items = node_manager_->MakeNodeList(); - index_items->PushBack(node_manager_->MakeIndexKeyNode("col4")); + index_items->PushBack(node_manager_->MakeIndexKeyNode("col4", "key")); index_items->PushBack(node_manager_->MakeIndexTsNode("col5")); ColumnIndexNode *index_node = dynamic_cast(node_manager_->MakeColumnIndexNode(index_items)); CreatePlanNode *node = node_manager_->MakeCreateTablePlanNode( diff --git a/hybridse/src/plan/planner.cc b/hybridse/src/plan/planner.cc index 3a3984c9b16..1dfed24f39a 100644 --- a/hybridse/src/plan/planner.cc +++ b/hybridse/src/plan/planner.cc @@ -1139,7 +1139,7 @@ bool Planner::ExpandCurrentHistoryWindow(std::vector index_names; @@ -1199,7 +1199,6 @@ base::Status Planner::TransformTableDef(const std::string &table_name, 
const Nod case node::kColumnIndex: { node::ColumnIndexNode *column_index = static_cast(column_desc); - if (column_index->GetName().empty()) { column_index->SetName(PlanAPI::GenerateName("INDEX", table->indexes_size())); } diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc index 23e56924ae2..ae6f7815ff2 100644 --- a/hybridse/src/planv2/ast_node_converter.cc +++ b/hybridse/src/planv2/ast_node_converter.cc @@ -1598,7 +1598,7 @@ base::Status ConvertColumnIndexNode(const zetasql::ASTIndexDefinition* ast_def_n } // case entry->name() -// "key" -> IndexKeyNode +// "key"/"ckey"/"skey" -> IndexKeyNode // "ts" -> IndexTsNode // "ttl" -> IndexTTLNode // "ttl_type" -> IndexTTLTypeNode @@ -1607,14 +1607,13 @@ base::Status ConvertIndexOption(const zetasql::ASTOptionsEntry* entry, node::Nod node::SqlNode** output) { auto name = entry->name()->GetAsString(); absl::string_view name_v(name); - if (absl::EqualsIgnoreCase("key", name_v)) { + if (absl::EqualsIgnoreCase("key", name_v) || absl::EqualsIgnoreCase("ckey", name_v) || absl::EqualsIgnoreCase("skey", name_v)) { switch (entry->value()->node_kind()) { case zetasql::AST_PATH_EXPRESSION: { std::string column_name; CHECK_STATUS( AstPathExpressionToString(entry->value()->GetAsOrNull(), &column_name)); - *output = node_manager->MakeIndexKeyNode(column_name); - + *output = node_manager->MakeIndexKeyNode(column_name, absl::AsciiStrToLower(name_v)); return base::Status::OK(); } case zetasql::AST_STRUCT_CONSTRUCTOR_WITH_PARENS: { @@ -1632,7 +1631,7 @@ base::Status ConvertIndexOption(const zetasql::ASTOptionsEntry* entry, node::Nod ast_struct_expr->field_expression(0)->GetAsOrNull(), &key_str)); node::IndexKeyNode* index_keys = - dynamic_cast(node_manager->MakeIndexKeyNode(key_str)); + dynamic_cast(node_manager->MakeIndexKeyNode(key_str, absl::AsciiStrToLower(name_v))); for (int i = 1; i < field_expr_len; ++i) { std::string key; @@ -1643,7 +1642,6 @@ base::Status ConvertIndexOption(const 
zetasql::ASTOptionsEntry* entry, node::Nod index_keys->AddKey(key); } *output = index_keys; - return base::Status::OK(); } default: { @@ -2256,13 +2254,34 @@ base::Status ConvertCreateIndexStatement(const zetasql::ASTCreateIndexStatement* keys.push_back(path.back()); } node::SqlNodeList* index_node_list = node_manager->MakeNodeList(); - - node::SqlNode* index_key_node = node_manager->MakeIndexKeyNode(keys); + // extract index type from options + std::string index_type{"key"}; + if (root->options_list() != nullptr) { + for (const auto option : root->options_list()->options_entries()) { + if (auto name = option->name()->GetAsString(); absl::EqualsIgnoreCase(name, "type")) { + CHECK_TRUE(option->value()->node_kind() == zetasql::AST_PATH_EXPRESSION, common::kSqlAstError, + "Invalid index type, should be path expression"); + std::string type_name; + CHECK_STATUS( + AstPathExpressionToString(option->value()->GetAsOrNull(), &type_name)); + if (absl::EqualsIgnoreCase(type_name, "secondary")) { + index_type = "skey"; + } else if (!absl::EqualsIgnoreCase(type_name, "covering")) { + FAIL_STATUS(common::kSqlAstError, "Invalid index type: ", type_name); + } + } + } + } + node::SqlNode* index_key_node = node_manager->MakeIndexKeyNode(keys, index_type); index_node_list->PushBack(index_key_node); if (root->options_list() != nullptr) { for (const auto option : root->options_list()->options_entries()) { + // ignore type + if (auto name = option->name()->GetAsString(); absl::EqualsIgnoreCase(name, "type")) { + continue; + } node::SqlNode* node = nullptr; - CHECK_STATUS(ConvertIndexOption(option, node_manager, &node)); + CHECK_STATUS(ConvertIndexOption(option, node_manager, &node)); // option set secondary index type if (node != nullptr) { // NOTE: unhandled option will return OK, but node is not set index_node_list->PushBack(node); diff --git a/hybridse/src/sdk/codec_sdk.cc b/hybridse/src/sdk/codec_sdk.cc index 9b910dd28cd..c09216b2600 100644 --- a/hybridse/src/sdk/codec_sdk.cc +++ 
b/hybridse/src/sdk/codec_sdk.cc @@ -73,7 +73,7 @@ bool RowIOBufView::Reset(const butil::IOBuf& buf) { return false; } str_addr_length_ = codec::GetAddrLength(size_); - DLOG(INFO) << "size " << size_ << " addr length " << str_addr_length_; + DLOG(INFO) << "size " << size_ << " addr length " << (unsigned int)str_addr_length_; return true; } diff --git a/src/base/index_util.cc b/src/base/index_util.cc new file mode 100644 index 00000000000..679ce2deaa7 --- /dev/null +++ b/src/base/index_util.cc @@ -0,0 +1,121 @@ +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "base/index_util.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "base/glog_wrapper.h" +#include "storage/schema.h" + +namespace openmldb::base { +// , error if empty +std::map> MakePkeysHint(const codec::Schema& schema, + const common::ColumnKey& cidx_ck) { + if (cidx_ck.col_name().empty()) { + LOG(WARNING) << "empty cidx column key"; + return {}; + } + // pkey col idx in row + std::set pkey_set; + for (int i = 0; i < cidx_ck.col_name().size(); i++) { + pkey_set.insert(cidx_ck.col_name().Get(i)); + } + if (pkey_set.empty()) { + LOG(WARNING) << "empty pkey set"; + return {}; + } + if (pkey_set.size() != static_cast::size_type>(cidx_ck.col_name().size())) { + LOG(WARNING) << "pkey set size not equal to cidx pkeys size"; + return {}; + } + std::map> col_idx; + for (int i = 0; i < schema.size(); i++) { + if (pkey_set.find(schema.Get(i).name()) != pkey_set.end()) { + col_idx[schema.Get(i).name()] = {i, schema.Get(i).data_type()}; + } + } + if (col_idx.size() != pkey_set.size()) { + LOG(WARNING) << "col idx size not equal to cidx pkeys size"; + return {}; + } + return col_idx; +} + +// error if empty +std::string MakeDeleteSQL(const std::string& db, const std::string& name, const common::ColumnKey& cidx_ck, + const int8_t* values, uint64_t ts, const codec::RowView& row_view, + const std::map>& col_idx) { + auto sql_prefix = absl::StrCat("delete from ", db, ".", name, " where "); + std::string cond; + for (int i = 0; i < cidx_ck.col_name().size(); i++) { + // append primary keys, pkeys in dimension are encoded, so we should get them from raw value + // split can't work if string has `|` + auto& col_name = cidx_ck.col_name().Get(i); + auto col = col_idx.find(col_name); + if (col == col_idx.end()) { + LOG(WARNING) << "col " << col_name << " not found in col idx"; + return ""; + } + std::string val; + row_view.GetStrValue(values, col->second.first, &val); + if (!cond.empty()) { + absl::StrAppend(&cond, " and 
"); + } + // TODO(hw): string should add quotes how about timestamp? + // check existence before, so here we skip + absl::StrAppend(&cond, col_name); + if (auto t = col->second.second; t == type::kVarchar || t == type::kString) { + absl::StrAppend(&cond, "=\"", val, "\""); + } else { + absl::StrAppend(&cond, "=", val); + } + } + // ts must be integer, won't be string + if (!cidx_ck.ts_name().empty() && cidx_ck.ts_name() != storage::DEFAULT_TS_COL_NAME) { + if (!cond.empty()) { + absl::StrAppend(&cond, " and "); + } + absl::StrAppend(&cond, cidx_ck.ts_name(), "=", std::to_string(ts)); + } + auto sql = absl::StrCat(sql_prefix, cond, ";"); + // TODO(hw): if delete failed, we can't revert. And if sidx skeys+sts doesn't change, no need to delete and + // then insert + DLOG(INFO) << "delete sql " << sql; + return sql; +} + +// error if empty +std::string ExtractPkeys(const common::ColumnKey& cidx_ck, const int8_t* values, const codec::RowView& row_view, + const std::map>& col_idx) { + // join with | + std::vector pkeys; + for (int i = 0; i < cidx_ck.col_name().size(); i++) { + auto& col_name = cidx_ck.col_name().Get(i); + auto col = col_idx.find(col_name); + if (col == col_idx.end()) { + LOG(WARNING) << "col " << col_name << " not found in col idx"; + return ""; + } + std::string val; + row_view.GetStrValue(values, col->second.first, &val); + pkeys.push_back(val); + } + return absl::StrJoin(pkeys, "|"); +} + +} // namespace openmldb::base diff --git a/src/base/index_util.h b/src/base/index_util.h new file mode 100644 index 00000000000..11392b37bf0 --- /dev/null +++ b/src/base/index_util.h @@ -0,0 +1,45 @@ +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_BASE_INDEX_UTIL_H_ +#define SRC_BASE_INDEX_UTIL_H_ + +#include + +#include "codec/codec.h" + +namespace openmldb { +namespace base { + +// don't declare func in table header cuz swig sdk + +// , error if empty +std::map> MakePkeysHint(const codec::Schema& schema, + const common::ColumnKey& cidx_ck); + +// error if empty +std::string MakeDeleteSQL(const std::string& db, const std::string& name, const common::ColumnKey& cidx_ck, + const int8_t* values, uint64_t ts, const codec::RowView& row_view, + const std::map>& col_idx); + +// error if empty +std::string ExtractPkeys(const common::ColumnKey& cidx_ck, const int8_t* values, const codec::RowView& row_view, + const std::map>& col_idx); + +} // namespace base +} // namespace openmldb + +#endif // SRC_BASE_INDEX_UTIL_H_ diff --git a/src/base/status.h b/src/base/status.h index a2da254e78e..bde6a31960a 100644 --- a/src/base/status.h +++ b/src/base/status.h @@ -191,7 +191,7 @@ enum ReturnCode { }; struct Status { - Status(int code_i, std::string msg_i) : code(code_i), msg(msg_i) {} + Status(int code_i, const std::string& msg_i) : code(code_i), msg(msg_i) {} Status() : code(ReturnCode::kOk), msg("ok") {} inline bool OK() const { return code == ReturnCode::kOk; } inline const std::string& GetMsg() const { return msg; } diff --git a/src/base/status_util.h b/src/base/status_util.h index 1d0db238d61..e0bd5758304 100644 --- a/src/base/status_util.h +++ b/src/base/status_util.h @@ -161,6 +161,13 @@ LOG(WARNING) << "Status: " << _s->ToString(); \ } while (0) +#define APPEND_AND_WARN(s, msg) \ 
+ do { \ + ::hybridse::sdk::Status* _s = (s); \ + _s->Append((msg)); \ + LOG(WARNING) << "Status: " << _s->ToString(); \ + } while (0) + /// @brief s.msg += append_str, and warn it #define CODE_APPEND_AND_WARN(s, code, msg) \ do { \ diff --git a/src/catalog/distribute_iterator.cc b/src/catalog/distribute_iterator.cc index 032d3ec75f2..519dec5f2fa 100644 --- a/src/catalog/distribute_iterator.cc +++ b/src/catalog/distribute_iterator.cc @@ -155,7 +155,7 @@ bool FullTableIterator::NextFromRemote() { } } else { kv_it_ = iter->second->Traverse(tid_, cur_pid_, "", "", 0, FLAGS_traverse_cnt_limit, false, 0, count); - DLOG(INFO) << "count " << count; + DVLOG(1) << "count " << count; } if (kv_it_ && kv_it_->Valid()) { last_pk_ = kv_it_->GetLastPK(); diff --git a/src/client/tablet_client.cc b/src/client/tablet_client.cc index cbfb794817f..635e37d8f78 100644 --- a/src/client/tablet_client.cc +++ b/src/client/tablet_client.cc @@ -164,7 +164,7 @@ base::Status TabletClient::TruncateTable(uint32_t tid, uint32_t pid) { request.set_tid(tid); request.set_pid(pid); if (!client_.SendRequest(&::openmldb::api::TabletServer_Stub::TruncateTable, &request, &response, - FLAGS_request_timeout_ms, 1)) { + FLAGS_request_timeout_ms, 1)) { return {base::ReturnCode::kRPCError, "send request failed!"}; } else if (response.code() == 0) { return {}; @@ -178,7 +178,7 @@ base::Status TabletClient::CreateTable(const ::openmldb::api::TableMeta& table_m table_meta_ptr->CopyFrom(table_meta); ::openmldb::api::CreateTableResponse response; if (!client_.SendRequest(&::openmldb::api::TabletServer_Stub::CreateTable, &request, &response, - FLAGS_request_timeout_ms * 2, 1)) { + FLAGS_request_timeout_ms * 2, 1)) { return {base::ReturnCode::kRPCError, "send request failed!"}; } else if (response.code() == 0) { return {}; @@ -207,9 +207,8 @@ bool TabletClient::UpdateTableMetaForAddField(uint32_t tid, const std::vector>& dimensions, - int memory_usage_limit, bool put_if_absent) { - + const std::vector>& dimensions, 
int memory_usage_limit, + bool put_if_absent, bool check_exists) { ::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension> pb_dimensions; for (size_t i = 0; i < dimensions.size(); i++) { ::openmldb::api::Dimension* d = pb_dimensions.Add(); @@ -217,12 +216,12 @@ base::Status TabletClient::Put(uint32_t tid, uint32_t pid, uint64_t time, const d->set_idx(dimensions[i].second); } - return Put(tid, pid, time, base::Slice(value), &pb_dimensions, memory_usage_limit, put_if_absent); + return Put(tid, pid, time, base::Slice(value), &pb_dimensions, memory_usage_limit, put_if_absent, check_exists); } base::Status TabletClient::Put(uint32_t tid, uint32_t pid, uint64_t time, const base::Slice& value, - ::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension>* dimensions, - int memory_usage_limit, bool put_if_absent) { + ::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension>* dimensions, + int memory_usage_limit, bool put_if_absent, bool check_exists) { ::openmldb::api::PutRequest request; if (memory_usage_limit < 0 || memory_usage_limit > 100) { return {base::ReturnCode::kError, absl::StrCat("invalid memory_usage_limit ", memory_usage_limit)}; @@ -235,9 +234,10 @@ base::Status TabletClient::Put(uint32_t tid, uint32_t pid, uint64_t time, const request.set_pid(pid); request.mutable_dimensions()->Swap(dimensions); request.set_put_if_absent(put_if_absent); + request.set_check_exists(check_exists); ::openmldb::api::PutResponse response; - auto st = client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::Put, - &request, &response, FLAGS_request_timeout_ms, 1); + auto st = client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::Put, &request, &response, + FLAGS_request_timeout_ms, 1); if (!st.OK()) { return st; } @@ -245,7 +245,7 @@ base::Status TabletClient::Put(uint32_t tid, uint32_t pid, uint64_t time, const } base::Status TabletClient::Put(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, - const std::string& value) { + const 
std::string& value) { ::openmldb::api::PutRequest request; auto dim = request.add_dimensions(); dim->set_key(pk); @@ -255,8 +255,8 @@ base::Status TabletClient::Put(uint32_t tid, uint32_t pid, const std::string& pk request.set_tid(tid); request.set_pid(pid); ::openmldb::api::PutResponse response; - auto st = client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::Put, - &request, &response, FLAGS_request_timeout_ms, 1); + auto st = client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::Put, &request, &response, + FLAGS_request_timeout_ms, 1); if (!st.OK()) { return st; } @@ -369,7 +369,7 @@ base::Status TabletClient::LoadTable(const std::string& name, uint32_t tid, uint } base::Status TabletClient::LoadTableInternal(const ::openmldb::api::TableMeta& table_meta, - std::shared_ptr task_info) { + std::shared_ptr task_info) { ::openmldb::api::LoadTableRequest request; ::openmldb::api::TableMeta* cur_table_meta = request.mutable_table_meta(); cur_table_meta->CopyFrom(table_meta); @@ -524,7 +524,7 @@ bool TabletClient::GetManifest(uint32_t tid, uint32_t pid, ::openmldb::common::S base::Status TabletClient::GetTableStatus(::openmldb::api::GetTableStatusResponse& response) { ::openmldb::api::GetTableStatusRequest request; auto st = client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::GetTableStatus, &request, &response, - FLAGS_request_timeout_ms, 1); + FLAGS_request_timeout_ms, 1); if (st.OK()) { return {response.code(), response.msg()}; } @@ -536,14 +536,14 @@ base::Status TabletClient::GetTableStatus(uint32_t tid, uint32_t pid, ::openmldb } base::Status TabletClient::GetTableStatus(uint32_t tid, uint32_t pid, bool need_schema, - ::openmldb::api::TableStatus& table_status) { + ::openmldb::api::TableStatus& table_status) { ::openmldb::api::GetTableStatusRequest request; request.set_tid(tid); request.set_pid(pid); request.set_need_schema(need_schema); ::openmldb::api::GetTableStatusResponse response; auto st = 
client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::GetTableStatus, &request, &response, - FLAGS_request_timeout_ms, 1); + FLAGS_request_timeout_ms, 1); if (!st.OK()) { return st; } @@ -553,9 +553,10 @@ base::Status TabletClient::GetTableStatus(uint32_t tid, uint32_t pid, bool need_ return {response.code(), response.msg()}; } -std::shared_ptr TabletClient::Scan(uint32_t tid, uint32_t pid, - const std::string& pk, const std::string& idx_name, - uint64_t stime, uint64_t etime, uint32_t limit, uint32_t skip_record_num, std::string& msg) { +std::shared_ptr TabletClient::Scan(uint32_t tid, uint32_t pid, const std::string& pk, + const std::string& idx_name, uint64_t stime, + uint64_t etime, uint32_t limit, + uint32_t skip_record_num, std::string& msg) { ::openmldb::api::ScanRequest request; request.set_pk(pk); request.set_st(stime); @@ -569,7 +570,7 @@ std::shared_ptr TabletClient::Scan(uint32_t tid, request.set_skip_record_num(skip_record_num); auto response = std::make_shared(); bool ok = client_.SendRequest(&::openmldb::api::TabletServer_Stub::Scan, &request, response.get(), - FLAGS_request_timeout_ms, 1); + FLAGS_request_timeout_ms, 1); if (response->has_msg()) { msg = response->msg(); } @@ -579,9 +580,9 @@ std::shared_ptr TabletClient::Scan(uint32_t tid, return std::make_shared<::openmldb::base::ScanKvIterator>(pk, response); } -std::shared_ptr TabletClient::Scan(uint32_t tid, uint32_t pid, - const std::string& pk, const std::string& idx_name, - uint64_t stime, uint64_t etime, uint32_t limit, std::string& msg) { +std::shared_ptr TabletClient::Scan(uint32_t tid, uint32_t pid, const std::string& pk, + const std::string& idx_name, uint64_t stime, + uint64_t etime, uint32_t limit, std::string& msg) { return Scan(tid, pid, pk, idx_name, stime, etime, limit, 0, msg); } @@ -709,7 +710,7 @@ bool TabletClient::SetExpire(uint32_t tid, uint32_t pid, bool is_expire) { } base::Status TabletClient::GetTableFollower(uint32_t tid, uint32_t pid, uint64_t& offset, - std::map& 
info_map) { + std::map& info_map) { ::openmldb::api::GetTableFollowerRequest request; ::openmldb::api::GetTableFollowerResponse response; request.set_tid(tid); @@ -799,6 +800,57 @@ bool TabletClient::Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64 return true; } +base::Status TabletClient::Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, + const std::string& idx_name, std::string& value, uint64_t& ts) { + ::openmldb::api::GetRequest request; + ::openmldb::api::GetResponse response; + request.set_tid(tid); + request.set_pid(pid); + request.set_key(pk); + request.set_ts(time); + if (!idx_name.empty()) { + request.set_idx_name(idx_name); + } + auto st = client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::Get, &request, &response, + FLAGS_request_timeout_ms, 1); + if (!st.OK()) { + return st; + } + + if (response.code() == 0) { + value.swap(*response.mutable_value()); + ts = response.ts(); + } + return {response.code(), response.msg()}; +} + +base::Status TabletClient::Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t stime, api::GetType stype, + uint64_t etime, const std::string& idx_name, std::string& value, + uint64_t& ts) { + ::openmldb::api::GetRequest request; + ::openmldb::api::GetResponse response; + request.set_tid(tid); + request.set_pid(pid); + request.set_key(pk); + request.set_ts(stime); + request.set_type(stype); + request.set_et(etime); + if (!idx_name.empty()) { + request.set_idx_name(idx_name); + } + auto st = client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::Get, &request, &response, + FLAGS_request_timeout_ms, 1); + if (!st.OK()) { + return st; + } + + if (response.code() == 0) { + value.swap(*response.mutable_value()); + ts = response.ts(); + } + return {response.code(), response.msg()}; +} + bool TabletClient::Delete(uint32_t tid, uint32_t pid, const std::string& pk, const std::string& idx_name, std::string& msg) { ::openmldb::api::DeleteRequest request; @@ -840,8 +892,7 @@ 
base::Status TabletClient::Delete(uint32_t tid, uint32_t pid, const sdk::DeleteO request.set_ts_name(option.ts_name); } request.set_enable_decode_value(option.enable_decode_value); - bool ok = client_.SendRequest(&::openmldb::api::TabletServer_Stub::Delete, &request, &response, - timeout_ms, 1); + bool ok = client_.SendRequest(&::openmldb::api::TabletServer_Stub::Delete, &request, &response, timeout_ms, 1); if (!ok || response.code() != 0) { return {base::ReturnCode::kError, response.msg()}; } @@ -885,8 +936,10 @@ bool TabletClient::DeleteBinlog(uint32_t tid, uint32_t pid, openmldb::common::St } std::shared_ptr TabletClient::Traverse(uint32_t tid, uint32_t pid, - const std::string& idx_name, const std::string& pk, uint64_t ts, uint32_t limit, bool skip_current_pk, - uint32_t ts_pos, uint32_t& count) { + const std::string& idx_name, + const std::string& pk, uint64_t ts, + uint32_t limit, bool skip_current_pk, + uint32_t ts_pos, uint32_t& count) { ::openmldb::api::TraverseRequest request; auto response = std::make_shared(); request.set_tid(tid); @@ -966,8 +1019,8 @@ bool TabletClient::AddIndex(uint32_t tid, uint32_t pid, const ::openmldb::common } bool TabletClient::AddMultiIndex(uint32_t tid, uint32_t pid, - const std::vector<::openmldb::common::ColumnKey>& column_keys, - std::shared_ptr task_info) { + const std::vector<::openmldb::common::ColumnKey>& column_keys, + std::shared_ptr task_info) { ::openmldb::api::AddIndexRequest request; ::openmldb::api::GeneralResponse response; request.set_tid(tid); @@ -1039,9 +1092,8 @@ bool TabletClient::LoadIndexData(uint32_t tid, uint32_t pid, uint32_t partition_ } bool TabletClient::ExtractIndexData(uint32_t tid, uint32_t pid, uint32_t partition_num, - const std::vector<::openmldb::common::ColumnKey>& column_key, - uint64_t offset, bool dump_data, - std::shared_ptr task_info) { + const std::vector<::openmldb::common::ColumnKey>& column_key, uint64_t offset, + bool dump_data, std::shared_ptr task_info) { if (column_key.empty()) 
{ if (task_info) { task_info->set_status(::openmldb::api::TaskStatus::kFailed); @@ -1213,7 +1265,7 @@ bool TabletClient::CallSQLBatchRequestProcedure(const std::string& db, const std } bool static ParseBatchRequestMeta(const base::Slice& meta, const base::Slice& data, - ::openmldb::api::SQLBatchRequestQueryRequest* request) { + ::openmldb::api::SQLBatchRequestQueryRequest* request) { uint64_t total_len = 0; const int32_t* buf = reinterpret_cast(meta.data()); int32_t cnt = meta.size() / sizeof(int32_t); @@ -1238,9 +1290,9 @@ bool static ParseBatchRequestMeta(const base::Slice& meta, const base::Slice& da } base::Status TabletClient::CallSQLBatchRequestProcedure(const std::string& db, const std::string& sp_name, - const base::Slice& meta, const base::Slice& data, - bool is_debug, uint64_t timeout_ms, - brpc::Controller* cntl, openmldb::api::SQLBatchRequestQueryResponse* response) { + const base::Slice& meta, const base::Slice& data, bool is_debug, + uint64_t timeout_ms, brpc::Controller* cntl, + openmldb::api::SQLBatchRequestQueryResponse* response) { ::openmldb::api::SQLBatchRequestQueryRequest request; request.set_sp_name(sp_name); request.set_is_procedure(true); @@ -1264,10 +1316,9 @@ base::Status TabletClient::CallSQLBatchRequestProcedure(const std::string& db, c return {}; } -base::Status TabletClient::CallSQLBatchRequestProcedure(const std::string& db, const std::string& sp_name, - const base::Slice& meta, const base::Slice& data, - bool is_debug, uint64_t timeout_ms, - openmldb::RpcCallback* callback) { +base::Status TabletClient::CallSQLBatchRequestProcedure( + const std::string& db, const std::string& sp_name, const base::Slice& meta, const base::Slice& data, bool is_debug, + uint64_t timeout_ms, openmldb::RpcCallback* callback) { if (callback == nullptr) { return {base::ReturnCode::kError, "callback is null"}; } @@ -1286,8 +1337,8 @@ base::Status TabletClient::CallSQLBatchRequestProcedure(const std::string& db, c return {base::ReturnCode::kError, "append to 
iobuf error"}; } callback->GetController()->set_timeout_ms(timeout_ms); - if (!client_.SendRequest(&::openmldb::api::TabletServer_Stub::SQLBatchRequestQuery, - callback->GetController().get(), &request, callback->GetResponse().get(), callback)) { + if (!client_.SendRequest(&::openmldb::api::TabletServer_Stub::SQLBatchRequestQuery, callback->GetController().get(), + &request, callback->GetResponse().get(), callback)) { return {base::ReturnCode::kError, "stub is null"}; } return {}; @@ -1384,9 +1435,9 @@ bool TabletClient::DropFunction(const ::openmldb::common::ExternalFun& fun, std: return true; } -bool TabletClient::CreateAggregator(const ::openmldb::api::TableMeta& base_table_meta, - uint32_t aggr_tid, uint32_t aggr_pid, uint32_t index_pos, - const ::openmldb::base::LongWindowInfo& window_info) { +bool TabletClient::CreateAggregator(const ::openmldb::api::TableMeta& base_table_meta, uint32_t aggr_tid, + uint32_t aggr_pid, uint32_t index_pos, + const ::openmldb::base::LongWindowInfo& window_info) { ::openmldb::api::CreateAggregatorRequest request; ::openmldb::api::TableMeta* base_meta_ptr = request.mutable_base_table_meta(); base_meta_ptr->CopyFrom(base_table_meta); @@ -1412,7 +1463,7 @@ bool TabletClient::CreateAggregator(const ::openmldb::api::TableMeta& base_table bool TabletClient::GetAndFlushDeployStats(::openmldb::api::DeployStatsResponse* res) { ::openmldb::api::GAFDeployStatsRequest req; bool ok = client_.SendRequest(&::openmldb::api::TabletServer_Stub::GetAndFlushDeployStats, &req, res, - FLAGS_request_timeout_ms, FLAGS_request_max_retry); + FLAGS_request_timeout_ms, FLAGS_request_max_retry); return ok && res->code() == 0; } diff --git a/src/client/tablet_client.h b/src/client/tablet_client.h index 33188adadcc..8099c599d15 100644 --- a/src/client/tablet_client.h +++ b/src/client/tablet_client.h @@ -78,20 +78,24 @@ class TabletClient : public Client { base::Status Put(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, const std::string& 
value); base::Status Put(uint32_t tid, uint32_t pid, uint64_t time, const std::string& value, - const std::vector>& dimensions, - int memory_usage_limit = 0, bool put_if_absent = false); + const std::vector>& dimensions, int memory_usage_limit = 0, + bool put_if_absent = false, bool check_exists = false); base::Status Put(uint32_t tid, uint32_t pid, uint64_t time, const base::Slice& value, - ::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension>* dimensions, - int memory_usage_limit = 0, bool put_if_absent = false); + ::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension>* dimensions, + int memory_usage_limit = 0, bool put_if_absent = false, bool check_exists = false); bool Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, std::string& value, // NOLINT uint64_t& ts, // NOLINT - std::string& msg); // NOLINT + std::string& msg); // NOLINT bool Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, const std::string& idx_name, std::string& value, uint64_t& ts, std::string& msg); // NOLINT - + base::Status Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, const std::string& idx_name, + std::string& value, uint64_t& ts); // NOLINT + base::Status Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t stime, api::GetType stype, + uint64_t etime, const std::string& idx_name, std::string& value, + uint64_t& ts); // NOLINT bool Delete(uint32_t tid, uint32_t pid, const std::string& pk, const std::string& idx_name, std::string& msg); // NOLINT diff --git a/src/cmd/display.h b/src/cmd/display.h index 34e1f851e39..0d7d2819964 100644 --- a/src/cmd/display.h +++ b/src/cmd/display.h @@ -105,6 +105,7 @@ __attribute__((unused)) static void PrintColumnKey( t.add("ts"); t.add("ttl"); t.add("ttl_type"); + t.add("type"); t.end_of_row(); int index_pos = 1; for (int i = 0; i < column_key_field.size(); i++) { @@ -141,7 +142,7 @@ __attribute__((unused)) static void PrintColumnKey( t.add("-"); // ttl 
t.add("-"); // ttl_type } - + t.add(common::IndexType_Name(column_key.type())); t.end_of_row(); } stream << t; diff --git a/src/cmd/sql_cmd_test.cc b/src/cmd/sql_cmd_test.cc index 79225eb52dd..0b29ae449cd 100644 --- a/src/cmd/sql_cmd_test.cc +++ b/src/cmd/sql_cmd_test.cc @@ -1284,7 +1284,7 @@ TEST_P(DBSDKTest, Truncate) { sr->ExecuteSQL(absl::StrCat("insert into ", table_name, " values ('", key, "', 11, ", ts, ");"), &status); } } - absl::SleepFor(absl::Seconds(5)); + absl::SleepFor(absl::Seconds(16)); // sleep more to avoid truncate failed on partition offset mismatch res = sr->ExecuteSQL(absl::StrCat("select * from ", table_name, ";"), &status); ASSERT_EQ(res->Size(), 100); @@ -1556,18 +1556,18 @@ TEST_P(DBSDKTest, SQLDeletetRow) { res = sr->ExecuteSQL(absl::StrCat("select * from ", table_name, ";"), &status); ASSERT_EQ(res->Size(), 3); std::string delete_sql = "delete from " + table_name + " where c1 = ?;"; - auto insert_row = sr->GetDeleteRow(db_name, delete_sql, &status); + auto delete_row = sr->GetDeleteRow(db_name, delete_sql, &status); ASSERT_TRUE(status.IsOK()); - insert_row->SetString(1, "key3"); - ASSERT_TRUE(insert_row->Build()); - sr->ExecuteDelete(insert_row, &status); + delete_row->SetString(1, "key3"); + ASSERT_TRUE(delete_row->Build()); + sr->ExecuteDelete(delete_row, &status); ASSERT_TRUE(status.IsOK()); res = sr->ExecuteSQL(absl::StrCat("select * from ", table_name, ";"), &status); ASSERT_EQ(res->Size(), 2); - insert_row->Reset(); - insert_row->SetString(1, "key100"); - ASSERT_TRUE(insert_row->Build()); - sr->ExecuteDelete(insert_row, &status); + delete_row->Reset(); + delete_row->SetString(1, "key100"); + ASSERT_TRUE(delete_row->Build()); + sr->ExecuteDelete(delete_row, &status); ASSERT_TRUE(status.IsOK()); res = sr->ExecuteSQL(absl::StrCat("select * from ", table_name, ";"), &status); ASSERT_EQ(res->Size(), 2); diff --git a/src/codec/field_codec.h b/src/codec/field_codec.h index 452578ff9fc..14f5ec14a5a 100644 --- a/src/codec/field_codec.h +++ 
b/src/codec/field_codec.h @@ -35,8 +35,8 @@ namespace codec { template static bool AppendColumnValue(const std::string& v, hybridse::sdk::DataType type, bool is_not_null, const std::string& null_value, T row) { - // check if null - if (v == null_value) { + // check if null, empty string will cast fail and throw bad_lexical_cast + if (v.empty() || v == null_value) { if (is_not_null) { return false; } diff --git a/src/flags.cc b/src/flags.cc index 2a061dbd263..744da1dac64 100644 --- a/src/flags.cc +++ b/src/flags.cc @@ -187,3 +187,7 @@ DEFINE_int32(sync_job_timeout, 30 * 60 * 1000, "sync job timeout, unit is milliseconds, should <= server.channel_keep_alive_time in TaskManager"); DEFINE_int32(deploy_job_max_wait_time_ms, 30 * 60 * 1000, "the max wait time of waiting deploy job"); DEFINE_bool(skip_grant_tables, true, "skip the grant tables"); + +// iot +// not exactly size, may plus some TODO(hw): too small? +DEFINE_uint32(cidx_gc_max_size, 1000, "config the max size for one cidx segment gc"); diff --git a/src/nameserver/name_server_impl.cc b/src/nameserver/name_server_impl.cc index 5e65a7d2d94..6aabec47d1f 100644 --- a/src/nameserver/name_server_impl.cc +++ b/src/nameserver/name_server_impl.cc @@ -1564,8 +1564,7 @@ bool NameServerImpl::Init(const std::string& zk_cluster, const std::string& zk_p task_thread_pool_.DelayTask(FLAGS_make_snapshot_check_interval, boost::bind(&NameServerImpl::SchedMakeSnapshot, this)); std::shared_ptr<::openmldb::nameserver::TableInfo> table_info; - while ( - !GetTableInfo(::openmldb::nameserver::USER_INFO_NAME, ::openmldb::nameserver::INTERNAL_DB, &table_info)) { + while (!GetTableInfo(::openmldb::nameserver::USER_INFO_NAME, ::openmldb::nameserver::INTERNAL_DB, &table_info)) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); } return true; @@ -3818,6 +3817,8 @@ void NameServerImpl::CreateTable(RpcController* controller, const CreateTableReq table_info->set_partition_num(1); table_info->set_replica_num(1); } + // TODO(hw): valid 
index pattern 1. all covering 2. clustered + secondary/covering(only one clustered and it should + // be the first one) auto status = schema::SchemaAdapter::CheckTableMeta(*table_info); if (!status.OK()) { PDLOG(WARNING, status.msg.c_str()); @@ -8675,6 +8676,11 @@ void NameServerImpl::AddIndex(RpcController* controller, const AddIndexRequest* std::vector<::openmldb::common::ColumnKey> column_key_vec; if (request->column_keys_size() > 0) { for (const auto& column_key : request->column_keys()) { + if (column_key.type() == common::IndexType::kClustered) { + base::SetResponseStatus(ReturnCode::kWrongColumnKey, "add clustered index is not allowed", response); + LOG(WARNING) << "add clustered index is not allowed"; + return; + } column_key_vec.push_back(column_key); } } else { @@ -9530,8 +9536,8 @@ base::Status NameServerImpl::CreateProcedureOnTablet(const ::openmldb::api::Crea ", endpoint: ", tb_client->GetEndpoint(), ", msg: ", status.GetMsg())}; } DLOG(INFO) << "create procedure on tablet success. 
db_name: " << sp_info.db_name() << ", " - << "sp_name: " << sp_info.sp_name() << ", " << "sql: " << sp_info.sql() - << "endpoint: " << tb_client->GetEndpoint(); + << "sp_name: " << sp_info.sp_name() << ", " + << "sql: " << sp_info.sql() << "endpoint: " << tb_client->GetEndpoint(); } return {}; } @@ -10120,11 +10126,7 @@ void NameServerImpl::ShowFunction(RpcController* controller, const ShowFunctionR base::Status NameServerImpl::InitGlobalVarTable() { std::map default_value = { - {"execute_mode", "online"}, - {"enable_trace", "false"}, - {"sync_job", "false"}, - {"job_timeout", "20000"} - }; + {"execute_mode", "online"}, {"enable_trace", "false"}, {"sync_job", "false"}, {"job_timeout", "20000"}}; // get table_info std::string db = INFORMATION_SCHEMA_DB; std::string table = GLOBAL_VARIABLES; diff --git a/src/proto/common.proto b/src/proto/common.proto index 8241e646f34..ee4c23e1c68 100755 --- a/src/proto/common.proto +++ b/src/proto/common.proto @@ -64,12 +64,19 @@ message TTLSt { optional uint64 lat_ttl = 3 [default = 0]; } +enum IndexType { + kCovering = 0; + kClustered = 1; + kSecondary = 2; +} + message ColumnKey { optional string index_name = 1; repeated string col_name = 2; optional string ts_name = 3; optional uint32 flag = 4 [default = 0]; // 0 mean index exist, 1 mean index has been deleted optional TTLSt ttl = 5; + optional IndexType type = 6 [default = kCovering]; } message EndpointAndTid { diff --git a/src/proto/tablet.proto b/src/proto/tablet.proto index bc160a01f1e..253eb35b33e 100755 --- a/src/proto/tablet.proto +++ b/src/proto/tablet.proto @@ -196,6 +196,7 @@ message PutRequest { optional uint32 format_version = 8 [default = 0, deprecated = true]; optional uint32 memory_limit = 9; optional bool put_if_absent = 10 [default = false]; + optional bool check_exists = 11 [default = false]; } message PutResponse { diff --git a/src/sdk/node_adapter.cc b/src/sdk/node_adapter.cc index c7c0d191922..d6b3979cfe7 100644 --- a/src/sdk/node_adapter.cc +++ 
b/src/sdk/node_adapter.cc @@ -383,6 +383,7 @@ bool NodeAdapter::TransformToTableDef(::hybridse::node::CreatePlanNode* create_n if (!TransformToColumnKey(column_index, column_names, index, status)) { return false; } + DLOG(INFO) << "index column key [" << index->ShortDebugString() << "]"; break; } @@ -471,6 +472,12 @@ bool NodeAdapter::TransformToColumnKey(hybridse::node::ColumnIndexNode* column_i for (const auto& key : column_index->GetKey()) { index->add_col_name(key); } + auto& type = column_index->GetIndexType(); + if (type == "skey") { + index->set_type(common::IndexType::kSecondary); + } else if (type == "ckey") { + index->set_type(common::IndexType::kClustered); + } // else default type kCovering // if no column_names, skip check if (!column_names.empty()) { for (const auto& col : index->col_name()) { diff --git a/src/sdk/option.h b/src/sdk/option.h index 3acb4e30afa..ae6fef7dfac 100644 --- a/src/sdk/option.h +++ b/src/sdk/option.h @@ -17,16 +17,26 @@ #ifndef SRC_SDK_OPTION_H_ #define SRC_SDK_OPTION_H_ +#include #include +#include "absl/strings/str_cat.h" + namespace openmldb { namespace sdk { struct DeleteOption { DeleteOption(std::optional idx_i, const std::string& key_i, const std::string& ts_name_i, - std::optional start_ts_i, std::optional end_ts_i) : - idx(idx_i), key(key_i), ts_name(ts_name_i), start_ts(start_ts_i), end_ts(end_ts_i) {} + std::optional start_ts_i, std::optional end_ts_i) + : idx(idx_i), key(key_i), ts_name(ts_name_i), start_ts(start_ts_i), end_ts(end_ts_i) {} DeleteOption() = default; + std::string DebugString() { + return absl::StrCat("idx: ", idx.has_value() ? std::to_string(idx.value()) : "-1", ", key: ", key, + ", ts_name: ", ts_name, + ", start_ts: ", start_ts.has_value() ? std::to_string(start_ts.value()) : "-1", + ", end_ts: ", end_ts.has_value() ? std::to_string(end_ts.value()) : "-1", + ", enable_decode_value: ", enable_decode_value ? 
"true" : "false"); + } std::optional idx = std::nullopt; std::string key; std::string ts_name; diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index 3d09156fdcc..8ef74f8ac2d 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -37,6 +37,7 @@ #include "base/file_util.h" #include "base/glog_wrapper.h" #include "base/status_util.h" +#include "base/index_util.h" #include "boost/none.hpp" #include "boost/property_tree/ini_parser.hpp" #include "boost/property_tree/ptree.hpp" @@ -63,6 +64,7 @@ #include "sdk/result_set_sql.h" #include "sdk/sdk_util.h" #include "sdk/split.h" +#include "storage/index_organized_table.h" #include "udf/udf.h" #include "vm/catalog.h" #include "vm/engine.h" @@ -1284,12 +1286,11 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s std::vector fails; if (!codegen_rows.empty()) { - for (size_t i = 0 ; i < codegen_rows.size(); ++i) { + for (size_t i = 0; i < codegen_rows.size(); ++i) { auto r = codegen_rows[i]; auto row = std::make_shared(table_info, schema, r, put_if_absent); if (!PutRow(table_info->tid(), row, tablets, status)) { - LOG(WARNING) << "fail to put row[" - << "] due to: " << status->msg; + LOG(WARNING) << "fail to put row[" << i << "] due to: " << status->msg; fails.push_back(i); continue; } @@ -1323,12 +1324,156 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s return true; } +bool IsIOT(const nameserver::TableInfo& table_info) { + auto& cks = table_info.column_key(); + if (cks.empty()) { + LOG(WARNING) << "no index in meta"; + return false; + } + if (cks[0].has_type() && cks[0].type() == common::IndexType::kClustered) { + // check other indexes + for (int i = 1; i < cks.size(); i++) { + if (cks[i].has_type() && cks[i].type() == common::IndexType::kClustered) { + LOG(WARNING) << "should be only one clustered index"; + return false; + } + } + return true; + } + return false; +} + +// clustered index idx must be 0 
+bool IsClusteredIndexIdx(const openmldb::api::Dimension& dim_index) { return dim_index.idx() == 0; } +bool IsClusteredIndexIdx(const std::pair& dim_index) { return dim_index.second == 0; } + +std::string ClusteredIndexTsName(const nameserver::TableInfo& table_info) { + auto& cks = table_info.column_key(); + if (cks.empty()) { + LOG(WARNING) << "no index in meta"; + return ""; + } + if (cks[0].has_ts_name() && cks[0].ts_name() != storage::DEFAULT_TS_COL_NAME) { + return cks[0].ts_name(); + } + // if default ts col, return empty string + return ""; +} + bool SQLClusterRouter::PutRow(uint32_t tid, const std::shared_ptr& row, const std::vector>& tablets, ::hybridse::sdk::Status* status) { RET_FALSE_IF_NULL_AND_WARN(status, "output status is nullptr"); const auto& dimensions = row->GetDimensions(); uint64_t cur_ts = ::baidu::common::timer::get_micros() / 1000; + // if iot, check if primary key exists in cidx + if (IsIOT(row->GetTableInfo())) { + if (row->IsPutIfAbsent()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "put_if_absent is not supported for iot table"); + return false; + } + // dimensions map>, find the idxid == 0 + bool valid = false; + uint64_t ts = 0; + std::string exists_value; // if empty, no primary key exists + + auto cols = row->GetTableInfo().column_desc(); // copy + codec::RowView row_view(cols); + // get cidx pid tablet for existence check + for (const auto& kv : dimensions) { + uint32_t pid = kv.first; + for (auto& pair : kv.second) { + if (IsClusteredIndexIdx(pair)) { + // check if primary key exists on tablet + auto tablet = tablets[pid]; + if (!tablet) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, + "tablet accessor is nullptr, can't check clustered index"); + return false; + } + auto client = tablet->GetClient(); + if (!client) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, + "tablet client is nullptr, can't check clustered index"); + return false; + } + int64_t get_ts = 0; + auto ts_name = 
ClusteredIndexTsName(row->GetTableInfo()); + if (!ts_name.empty()) { + bool found = false; + for (int i = 0; i < cols.size(); i++) { + if (cols.Get(i).name() == ts_name) { + row_view.GetInteger(reinterpret_cast(row->GetRow().c_str()), i, + cols.Get(i).data_type(), &get_ts); + found = true; + break; + } + } + if (!found || get_ts < 0) { + SET_STATUS_AND_WARN( + status, StatusCode::kCmdError, + found ? "invalid ts " + std::to_string(get_ts) : "get ts column failed"); + return false; + } + } else { + DLOG(INFO) << "no ts column in cidx"; + } + // if get_ts == 0, cidx may be without ts column, you should check ts col in cidx info, not by + // get_ts + DLOG(INFO) << "get primary key on iot table, pid " << pid << ", key " << pair.first << ", ts " + << get_ts; + // get rpc can't read all data(expired data may still in data skiplist), so we use put to check + // exists only check in cidx, no real insertion. get_ts may not be the current time, it can be ts + // col value, it's a bit different. + auto st = client->Put(tid, pid, get_ts, row->GetRow(), {pair}, + insert_memory_usage_limit_.load(std::memory_order_relaxed), false, true); + if (!st.OK() && st.GetCode() != base::ReturnCode::kKeyNotFound) { + APPEND_FROM_BASE_AND_WARN(status, st, "get primary key failed"); + return false; + } + valid = true; + DLOG(INFO) << "Get result: " << st.ToString(); + // check result, won't set exists_value if key not found + if (st.OK()) { + DLOG(INFO) << "primary key exists on iot table"; + exists_value = row->GetRow(); + ts = get_ts; + } + } + } + } + if (!valid) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, + "can't check primary key on iot table, meta/connection error"); + return false; + } + DLOG_IF(INFO, exists_value.empty()) << "primary key not exists, safe to insert"; + if (!exists_value.empty()) { + // delete old data then insert new data, no concurrency control, be careful + // revertput or SQLDeleteRow is not easy to use here, so make a sql + DLOG(INFO) << "primary key 
exists, delete old data then insert new data"; + // just where primary key, not all columns(redundant condition) + auto hint = base::MakePkeysHint(row->GetTableInfo().column_desc(), + row->GetTableInfo().column_key(0)); + if (hint.empty()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "make pkeys hint failed"); + return false; + } + auto sql = base::MakeDeleteSQL(row->GetTableInfo().db(), row->GetTableInfo().name(), + row->GetTableInfo().column_key(0), + (int8_t*)exists_value.c_str(), ts, row_view, hint); + if (sql.empty()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "make delete sql failed"); + return false; + } + ExecuteSQL(sql, status); + if (status->code != 0) { + PREPEND_AND_WARN(status, "delete old data failed"); + return false; + } + DLOG(INFO) << "delete old data success"; + } + } for (const auto& kv : dimensions) { uint32_t pid = kv.first; if (pid < tablets.size()) { @@ -1342,16 +1487,17 @@ bool SQLClusterRouter::PutRow(uint32_t tid, const std::shared_ptr& client->Put(tid, pid, cur_ts, row->GetRow(), kv.second, insert_memory_usage_limit_.load(std::memory_order_relaxed), row->IsPutIfAbsent()); if (!ret.OK()) { - if (RevertPut(row->GetTableInfo(), pid, dimensions, cur_ts, base::Slice(row->GetRow()), tablets) - .IsOK()) { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, - absl::StrCat("INSERT failed, tid ", tid)); + APPEND_FROM_BASE(status, ret, "put failed"); + if (auto rp = RevertPut(row->GetTableInfo(), pid, dimensions, cur_ts, + base::Slice(row->GetRow()), tablets); + rp.IsOK()) { + APPEND_AND_WARN(status, "tid " + std::to_string(tid) + ". RevertPut success."); } else { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, - "INSERT failed, tid " + std::to_string(tid) + - ". Note that data might have been partially inserted. " - "You are encouraged to perform DELETE to remove any partially " - "inserted data before trying INSERT again."); + APPEND_AND_WARN(status, "tid " + std::to_string(tid) + + ". 
RevertPut failed: " + rp.ToString() + + "Note that data might have been partially inserted. " + "You are encouraged to perform DELETE to remove any " + "partially inserted data before trying INSERT again."); } return false; } @@ -1431,7 +1577,7 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& n std::vector> tablets; bool ret = cluster_sdk_->GetTablet(db, name, &tablets); if (!ret || tablets.empty()) { - status->msg = "fail to get table " + name + " tablet"; + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "fail to get table " + name + " tablet"); return false; } std::map> dimensions_map; @@ -1454,6 +1600,114 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& n } base::Slice row_value(value, len); uint64_t cur_ts = ::baidu::common::timer::get_micros() / 1000; + // TODO(hw): refactor with PutRow + // if iot, check if primary key exists in cidx + auto table_info = cluster_sdk_->GetTableInfo(db, name); + if (!table_info) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "fail to get table info"); + return false; + } + if (IsIOT(*table_info)) { + if (put_if_absent) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "put_if_absent is not supported for iot table"); + return false; + } + // dimensions map>, find the idxid == 0 + bool valid = false; + uint64_t ts = 0; + std::string exists_value; // if empty, no primary key exists + // TODO: ref putrow, fix later + auto cols = table_info->column_desc(); // copy + codec::RowView row_view(cols); + for (const auto& kv : dimensions_map) { + uint32_t pid = kv.first; + for (auto& pair : kv.second) { + if (IsClusteredIndexIdx(pair)) { + // check if primary key exists on tablet + auto tablet = tablets[pid]; + if (!tablet) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, + "tablet accessor is nullptr, can't check clustered index"); + return false; + } + auto client = tablet->GetClient(); + if (!client) { + SET_STATUS_AND_WARN(status, 
StatusCode::kCmdError, + "tablet client is nullptr, can't check clustered index"); + return false; + } + int64_t get_ts = 0; + if (auto ts_name = ClusteredIndexTsName(*table_info); !ts_name.empty()) { + bool found = false; + for (int i = 0; i < cols.size(); i++) { + if (cols.Get(i).name() == ts_name) { + row_view.GetInteger(reinterpret_cast(value), i, cols.Get(i).data_type(), + &get_ts); + found = true; + break; + } + } + if (!found || get_ts < 0) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, + found ? "invalid ts" + std::to_string(get_ts) : "get ts column failed"); + return false; + } + } else { + DLOG(INFO) << "no ts column in cidx"; + } + // if get_ts == 0, may be cidx without ts column + DLOG(INFO) << "get key " << pair.key() << ", ts " << get_ts; + ::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension> dims; + dims.Add()->CopyFrom(pair); + auto st = client->Put(tid, pid, get_ts, row_value, &dims, + insert_memory_usage_limit_.load(std::memory_order_relaxed), false, true); + if (!st.OK() && st.GetCode() != base::ReturnCode::kKeyNotFound) { + APPEND_FROM_BASE_AND_WARN(status, st, "get primary key failed"); + return false; + } + valid = true; + DLOG(INFO) << "Get result: " << st.ToString(); + // check result, won't set exists_value if key not found + if (st.OK()) { + DLOG(INFO) << "primary key exist on iot table"; + exists_value = value; + ts = get_ts; + } + } + } + } + if (!valid) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, + "can't check primary key on iot table, meta/connection error"); + return false; + } + DLOG_IF(INFO, exists_value.empty()) << "primary key not exists, safe to insert"; + if (!exists_value.empty()) { + // delete old data then insert new data, no concurrency control, be careful + // revertput or SQLDeleteRow is not easy to use here, so make a sql? 
+ DLOG(INFO) << "primary key exists, delete old data then insert new data"; + // just where primary key, not all columns(redundant condition) + auto hint = + base::MakePkeysHint(table_info->column_desc(), table_info->column_key(0)); + if (hint.empty()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "make pkeys hint failed"); + return false; + } + auto sql = base::MakeDeleteSQL(table_info->db(), table_info->name(), + table_info->column_key(0), + (int8_t*)exists_value.c_str(), ts, row_view, hint); + if (sql.empty()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "make delete sql failed"); + return false; + } + ExecuteSQL(sql, status); + if (status->code != 0) { + PREPEND_AND_WARN(status, "delete old data failed"); + return false; + } + } + } + for (auto& kv : dimensions_map) { uint32_t pid = kv.first; if (pid < tablets.size()) { @@ -1461,17 +1715,13 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& n if (tablet) { auto client = tablet->GetClient(); if (client) { - DLOG(INFO) << "put data to endpoint " << client->GetEndpoint() << " with dimensions size " - << kv.second.size(); + DVLOG(3) << "put data to endpoint " << client->GetEndpoint() << " with dimensions size " + << kv.second.size(); auto ret = client->Put(tid, pid, cur_ts, row_value, &kv.second, insert_memory_usage_limit_.load(std::memory_order_relaxed), put_if_absent); if (!ret.OK()) { // TODO(hw): show put failed row(readable)? ::hybridse::codec::RowView::GetRowString? - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, - "INSERT failed, tid " + std::to_string(tid) + - ". Note that data might have been partially inserted. 
" - "You are encouraged to perform DELETE to remove any partially " - "inserted data before trying INSERT again."); + APPEND_FROM_BASE(status, ret, "INSERT failed, tid " + std::to_string(tid)); std::map>> dimensions; for (const auto& val : dimensions_map) { std::vector> vec; @@ -1480,14 +1730,14 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& n } dimensions.emplace(val.first, std::move(vec)); } - auto table_info = cluster_sdk_->GetTableInfo(db, name); - if (!table_info) { - return false; - } + // TODO(hw): better to return absl::Status - if (RevertPut(*table_info, pid, dimensions, cur_ts, row_value, tablets).IsOK()) { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, - absl::StrCat("INSERT failed, tid ", tid)); + if (auto rp = RevertPut(*table_info, pid, dimensions, cur_ts, row_value, tablets); rp.IsOK()) { + APPEND_AND_WARN(status, "revert ok"); + } else { + APPEND_AND_WARN(status, + "revert failed. You are encouraged to perform DELETE to remove any " + "partially inserted data before trying INSERT again."); } return false; } @@ -1699,7 +1949,7 @@ std::shared_ptr SQLClusterRouter::HandleSQLCmd(const h } case hybridse::node::kCmdShowUser: { - std::vector value = { options_->user }; + std::vector value = {options_->user}; return ResultSetSQL::MakeResultSet({"User"}, {value}, status); } @@ -2120,6 +2370,32 @@ std::shared_ptr SQLClusterRouter::HandleSQLCmd(const h return {}; } +base::Status ValidateTableInfo(const nameserver::TableInfo& table_info) { + auto& indexs = table_info.column_key(); + if (indexs.empty()) { + LOG(INFO) << "no index specified, it'll add default index later"; + return {}; + } + if (indexs[0].type() == common::IndexType::kCovering) { + // MemTable, all other indexs should be covering + for (int i = 1; i < indexs.size(); i++) { + if (indexs[i].type() != common::IndexType::kCovering) { + return {base::ReturnCode::kInvalidArgs, "index " + std::to_string(i) + " should be covering"}; + } + } + } else if 
(indexs[0].type() == common::IndexType::kClustered) { + // IOT, no more clustered index, secondary and covering are valid + for (int i = 1; i < indexs.size(); i++) { + if (indexs[i].type() == common::IndexType::kClustered) { + return {base::ReturnCode::kInvalidArgs, "index " + std::to_string(i) + " should not be clustered"}; + } + } + } else { + return {base::ReturnCode::kInvalidArgs, "index 0 should be clustered or covering"}; + } + return {}; +} + base::Status SQLClusterRouter::HandleSQLCreateTable(hybridse::node::CreatePlanNode* create_node, const std::string& db, std::shared_ptr<::openmldb::client::NsClient> ns_ptr) { return HandleSQLCreateTable(create_node, db, ns_ptr, ""); @@ -2148,11 +2424,16 @@ base::Status SQLClusterRouter::HandleSQLCreateTable(hybridse::node::CreatePlanNo hybridse::base::Status sql_status; bool is_cluster_mode = cluster_sdk_->IsClusterMode(); + ::openmldb::sdk::NodeAdapter::TransformToTableDef(create_node, &table_info, default_replica_num, is_cluster_mode, &sql_status); if (sql_status.code != 0) { return base::Status(sql_status.code, sql_status.msg); } + // clustered should be the first index, user should set it, we don't adjust it + if (auto st = ValidateTableInfo(table_info); !st.OK()) { + return st; + } std::string msg; if (!ns_ptr->CreateTable(table_info, create_node->GetIfNotExist(), msg)) { return base::Status(base::ReturnCode::kSQLCmdRunError, msg); @@ -2764,7 +3045,8 @@ std::shared_ptr SQLClusterRouter::ExecuteSQL( } case hybridse::node::kPlanTypeCreateUser: { auto create_node = dynamic_cast(node); - UserInfo user_info;; + UserInfo user_info; + auto result = GetUser(create_node->Name(), &user_info); if (!result.ok()) { *status = {StatusCode::kCmdError, result.status().message()}; @@ -3005,7 +3287,7 @@ std::shared_ptr SQLClusterRouter::ExecuteSQL( if (is_online_mode) { // Handle in online mode config.emplace("spark.insert_memory_usage_limit", - std::to_string(insert_memory_usage_limit_.load(std::memory_order_relaxed))); + 
std::to_string(insert_memory_usage_limit_.load(std::memory_order_relaxed))); base_status = ImportOnlineData(sql, config, database, is_sync_job, offline_job_timeout, &job_info); } else { // Handle in offline mode @@ -3540,10 +3822,11 @@ hybridse::sdk::Status SQLClusterRouter::LoadDataMultipleFile(int id, int step, c const std::vector& file_list, const openmldb::sdk::LoadOptionsMapParser& options_parser, uint64_t* count) { + *count = 0; for (const auto& file : file_list) { uint64_t cur_count = 0; auto status = LoadDataSingleFile(id, step, database, table, file, options_parser, &cur_count); - DLOG(INFO) << "[thread " << id << "] Loaded " << count << " rows in " << file; + DLOG(INFO) << "[thread " << id << "] Loaded " << cur_count << " rows in " << file; if (!status.IsOK()) { return status; } @@ -3712,6 +3995,7 @@ hybridse::sdk::Status SQLClusterRouter::HandleDelete(const std::string& db, cons if (!status.IsOK()) { return status; } + DLOG(INFO) << "delete option: " << option.DebugString(); status = SendDeleteRequst(table_info, option); if (status.IsOK() && db != nameserver::INTERNAL_DB) { status = { @@ -4959,10 +5243,10 @@ std::shared_ptr SQLClusterRouter::GetNameServerJobResu } absl::StatusOr SQLClusterRouter::GetUser(const std::string& name, UserInfo* user_info) { - std::string sql = absl::StrCat("select * from ", nameserver::USER_INFO_NAME); + std::string sql = absl::StrCat("select * from ", nameserver::USER_INFO_NAME); hybridse::sdk::Status status; - auto rs = ExecuteSQLParameterized(nameserver::INTERNAL_DB, sql, - std::shared_ptr(), &status); + auto rs = + ExecuteSQLParameterized(nameserver::INTERNAL_DB, sql, std::shared_ptr(), &status); if (rs == nullptr) { return absl::InternalError(status.msg); } @@ -5014,6 +5298,8 @@ hybridse::sdk::Status SQLClusterRouter::UpdateUser(const UserInfo& user_info, co } hybridse::sdk::Status SQLClusterRouter::DeleteUser(const std::string& name) { + std::string sql = + absl::StrCat("delete from ", nameserver::USER_INFO_NAME, " where 
host = '%' and user = '", name, "';"); hybridse::sdk::Status status; auto ns_client = cluster_sdk_->GetNsClient(); @@ -5035,12 +5321,10 @@ void SQLClusterRouter::AddUserToConfig(std::map* confi } } -::hybridse::sdk::Status SQLClusterRouter::RevertPut(const nameserver::TableInfo& table_info, - uint32_t end_pid, - const std::map>>& dimensions, - uint64_t ts, - const base::Slice& value, - const std::vector>& tablets) { +::hybridse::sdk::Status SQLClusterRouter::RevertPut( + const nameserver::TableInfo& table_info, uint32_t end_pid, + const std::map>>& dimensions, uint64_t ts, + const base::Slice& value, const std::vector>& tablets) { codec::RowView row_view(table_info.column_desc()); std::map column_map; for (int32_t i = 0; i < table_info.column_desc_size(); i++) { diff --git a/src/storage/index_organized_table.cc b/src/storage/index_organized_table.cc new file mode 100644 index 00000000000..aeb3302b22b --- /dev/null +++ b/src/storage/index_organized_table.cc @@ -0,0 +1,634 @@ +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "storage/index_organized_table.h" + +#include + +#include "absl/strings/str_join.h" // dlog +#include "absl/strings/str_split.h" +#include "sdk/sql_router.h" +#include "storage/iot_segment.h" +#include "base/index_util.h" + +DECLARE_uint32(absolute_default_skiplist_height); + +namespace openmldb::storage { + +IOTIterator* NewNullIterator() { + // if TimeEntries::Iterator is null, nothing will be used + return new IOTIterator(nullptr, type::CompressType::kNoCompress, {}); +} + +// TODO(hw): temp func to create iot iterator +IOTIterator* NewIOTIterator(Segment* segment, const Slice& key, Ticket& ticket, type::CompressType compress_type, + std::unique_ptr cidx_iter) { + void* entry = nullptr; + auto entries = segment->GetKeyEntries(); + if (entries == nullptr || segment->GetTsCnt() > 1 || entries->Get(key, entry) < 0 || entry == nullptr) { + return NewNullIterator(); + } + ticket.Push(reinterpret_cast(entry)); + return new IOTIterator(reinterpret_cast(entry)->entries.NewIterator(), compress_type, + std::move(cidx_iter)); +} + +IOTIterator* NewIOTIterator(Segment* segment, const Slice& key, uint32_t idx, Ticket& ticket, + type::CompressType compress_type, + std::unique_ptr cidx_iter) { + auto ts_idx_map = segment->GetTsIdxMap(); + auto pos = ts_idx_map.find(idx); + if (pos == ts_idx_map.end()) { + LOG(WARNING) << "can't find idx in segment"; + return NewNullIterator(); + } + auto entries = segment->GetKeyEntries(); + if (segment->GetTsCnt() == 1) { + return NewIOTIterator(segment, key, ticket, compress_type, std::move(cidx_iter)); + } + void* entry_arr = nullptr; + if (entries->Get(key, entry_arr) < 0 || entry_arr == nullptr) { + return NewNullIterator(); + } + auto entry = reinterpret_cast(entry_arr)[pos->second]; + ticket.Push(entry); + return new IOTIterator(entry->entries.NewIterator(), compress_type, std::move(cidx_iter)); +} + +TableIterator* IndexOrganizedTable::NewIterator(uint32_t index, const std::string& pk, Ticket& ticket) { + 
std::shared_ptr index_def = table_index_.GetIndex(index); + if (!index_def || !index_def->IsReady()) { + LOG(WARNING) << "index is invalid"; + return nullptr; + } + DLOG(INFO) << "new iter for index and pk " << index << " name " << index_def->GetName(); + uint32_t seg_idx = SegIdx(pk); + Slice spk(pk); + uint32_t real_idx = index_def->GetInnerPos(); + Segment* segment = GetSegment(real_idx, seg_idx); + auto ts_col = index_def->GetTsColumn(); + if (ts_col) { + // if secondary, use iot iterator + if (index_def->IsSecondaryIndex()) { + // get clustered index iter for secondary index + auto handler = catalog_->GetTable(GetDB(), GetName()); + if (!handler) { + LOG(WARNING) << "no TableHandler for " << GetDB() << "." << GetName(); + return nullptr; + } + auto tablet_table_handler = std::dynamic_pointer_cast(handler); + if (!tablet_table_handler) { + LOG(WARNING) << "convert TabletTableHandler failed for " << GetDB() << "." << GetName(); + return nullptr; + } + LOG(INFO) << "create iot iterator for pk"; + // TODO(hw): iter may be invalid if catalog updated + auto iter = + NewIOTIterator(segment, spk, ts_col->GetId(), ticket, GetCompressType(), + std::move(tablet_table_handler->GetWindowIterator(table_index_.GetIndex(0)->GetName()))); + return iter; + } + // clsutered and covering still use old iterator + return segment->NewIterator(spk, ts_col->GetId(), ticket, GetCompressType()); + } + // cidx without ts? or invalid case + DLOG(INFO) << "index ts col is null, reate no-ts iterator"; + // TODO(hw): sidx without ts? + return segment->NewIterator(spk, ticket, GetCompressType()); +} + +TraverseIterator* IndexOrganizedTable::NewTraverseIterator(uint32_t index) { + std::shared_ptr index_def = GetIndex(index); + if (!index_def || !index_def->IsReady()) { + PDLOG(WARNING, "index %u not found. 
tid %u pid %u", index, id_, pid_); + return nullptr; + } + DLOG(INFO) << "new traverse iter for index " << index << " name " << index_def->GetName(); + uint64_t expire_time = 0; + uint64_t expire_cnt = 0; + auto ttl = index_def->GetTTL(); + if (GetExpireStatus()) { // gc enabled + expire_time = GetExpireTime(*ttl); + expire_cnt = ttl->lat_ttl; + } + uint32_t real_idx = index_def->GetInnerPos(); + auto ts_col = index_def->GetTsColumn(); + if (ts_col) { + // if secondary, use iot iterator + if (index_def->IsSecondaryIndex()) { + // get clustered index iter for secondary index + auto handler = catalog_->GetTable(GetDB(), GetName()); + if (!handler) { + LOG(WARNING) << "no TableHandler for " << GetDB() << "." << GetName(); + return nullptr; + } + auto tablet_table_handler = std::dynamic_pointer_cast(handler); + if (!tablet_table_handler) { + LOG(WARNING) << "convert TabletTableHandler failed for " << GetDB() << "." << GetName(); + return nullptr; + } + LOG(INFO) << "create iot traverse iterator for traverse"; + // TODO(hw): iter may be invalid if catalog updated + auto iter = new IOTTraverseIterator( + GetSegments(real_idx), GetSegCnt(), ttl->ttl_type, expire_time, expire_cnt, ts_col->GetId(), + GetCompressType(), + std::move(tablet_table_handler->GetWindowIterator(table_index_.GetIndex(0)->GetName()))); + return iter; + } + DLOG(INFO) << "create memtable traverse iterator for traverse"; + return new MemTableTraverseIterator(GetSegments(real_idx), GetSegCnt(), ttl->ttl_type, expire_time, expire_cnt, + ts_col->GetId(), GetCompressType()); + } + DLOG(INFO) << "index ts col is null, reate no-ts iterator"; + return new MemTableTraverseIterator(GetSegments(real_idx), GetSegCnt(), ttl->ttl_type, expire_time, expire_cnt, 0, + GetCompressType()); +} + +::hybridse::vm::WindowIterator* IndexOrganizedTable::NewWindowIterator(uint32_t index) { + std::shared_ptr index_def = table_index_.GetIndex(index); + if (!index_def || !index_def->IsReady()) { + LOG(WARNING) << "index id " << 
index << " not found. tid " << id_ << " pid " << pid_; + return nullptr; + } + LOG(INFO) << "new window iter for index " << index << " name " << index_def->GetName(); + uint64_t expire_time = 0; + uint64_t expire_cnt = 0; + auto ttl = index_def->GetTTL(); + if (GetExpireStatus()) { + expire_time = GetExpireTime(*ttl); + expire_cnt = ttl->lat_ttl; + } + uint32_t real_idx = index_def->GetInnerPos(); + auto ts_col = index_def->GetTsColumn(); + uint32_t ts_idx = 0; + if (ts_col) { + ts_idx = ts_col->GetId(); + } + DLOG(INFO) << "ts is null? " << ts_col << ", ts_idx " << ts_idx; + // if secondary, use iot iterator + if (index_def->IsSecondaryIndex()) { + // get clustered index iter for secondary index + auto handler = catalog_->GetTable(GetDB(), GetName()); + if (!handler) { + LOG(WARNING) << "no TableHandler for " << GetDB() << "." << GetName(); + return nullptr; + } + auto tablet_table_handler = std::dynamic_pointer_cast(handler); + if (!tablet_table_handler) { + LOG(WARNING) << "convert TabletTableHandler failed for " << GetDB() << "." << GetName(); + return nullptr; + } + LOG(INFO) << "create iot key traverse iterator for window"; + // TODO(hw): iter may be invalid if catalog updated + auto iter = + new IOTKeyIterator(GetSegments(real_idx), GetSegCnt(), ttl->ttl_type, expire_time, expire_cnt, ts_idx, + GetCompressType(), tablet_table_handler, table_index_.GetIndex(0)->GetName()); + return iter; + } + return new MemTableKeyIterator(GetSegments(real_idx), GetSegCnt(), ttl->ttl_type, expire_time, expire_cnt, ts_idx, + GetCompressType()); +} + +bool IndexOrganizedTable::Init() { + if (!InitMeta()) { + LOG(WARNING) << "init meta failed. 
tid " << id_ << " pid " << pid_; + return false; + } + // IOTSegment should know which is the cidx, sidx and covering idx are both duplicate(even the values are different) + auto inner_indexs = table_index_.GetAllInnerIndex(); + for (uint32_t i = 0; i < inner_indexs->size(); i++) { + const std::vector& ts_vec = inner_indexs->at(i)->GetTsIdx(); + uint32_t cur_key_entry_max_height = KeyEntryMaxHeight(inner_indexs->at(i)); + + Segment** seg_arr = new Segment*[seg_cnt_]; + DLOG_ASSERT(!ts_vec.empty()) << "must have ts, include auto gen ts"; + if (!ts_vec.empty()) { + for (uint32_t j = 0; j < seg_cnt_; j++) { + // let segment know whether it is cidx + seg_arr[j] = new IOTSegment(cur_key_entry_max_height, ts_vec, inner_indexs->at(i)->GetTsIdxType()); + PDLOG(INFO, "init %u, %u segment. height %u, ts col num %u. tid %u pid %u", i, j, + cur_key_entry_max_height, ts_vec.size(), id_, pid_); + } + } else { + // unavaildable + for (uint32_t j = 0; j < seg_cnt_; j++) { + seg_arr[j] = new IOTSegment(cur_key_entry_max_height); + PDLOG(INFO, "init %u, %u segment. 
height %u tid %u pid %u", i, j, cur_key_entry_max_height, id_, pid_); + } + } + segments_[i] = seg_arr; + key_entry_max_height_ = cur_key_entry_max_height; + } + LOG(INFO) << "init iot table name " << name_ << ", id " << id_ << ", pid " << pid_ << ", seg_cnt " << seg_cnt_; + return true; +} + +bool IndexOrganizedTable::Put(const std::string& pk, uint64_t time, const char* data, uint32_t size) { + uint32_t seg_idx = SegIdx(pk); + Segment* segment = GetSegment(0, seg_idx); + if (segment == nullptr) { + return false; + } + Slice spk(pk); + segment->Put(spk, time, data, size); + record_byte_size_.fetch_add(GetRecordSize(size)); + return true; +} + +absl::Status IndexOrganizedTable::Put(uint64_t time, const std::string& value, const Dimensions& dimensions, + bool put_if_absent) { + if (dimensions.empty()) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": empty dimension")); + } + // inner index pos: -1 means invalid, so it's positive in inner_index_key_map + std::map inner_index_key_map; + std::pair cidx_inner_key_pair{-1, ""}; + std::vector secondary_inners; + for (auto iter = dimensions.begin(); iter != dimensions.end(); iter++) { + int32_t inner_pos = table_index_.GetInnerIndexPos(iter->idx()); + if (inner_pos < 0) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid dimension idx ", iter->idx())); + } + if (iter->idx() == 0) { + cidx_inner_key_pair = {inner_pos, iter->key()}; + } + inner_index_key_map.emplace(inner_pos, iter->key()); + } + + const int8_t* data = reinterpret_cast(value.data()); + std::string uncompress_data; + uint32_t data_length = value.length(); + if (GetCompressType() == openmldb::type::kSnappy) { + snappy::Uncompress(value.data(), value.size(), &uncompress_data); + data = reinterpret_cast(uncompress_data.data()); + data_length = uncompress_data.length(); + } + if (data_length < codec::HEADER_LENGTH) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid value")); + } + uint8_t 
version = codec::RowView::GetSchemaVersion(data); + auto decoder = GetVersionDecoder(version); + if (decoder == nullptr) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid schema version ", version)); + } + std::optional clustered_tsv; + std::map> ts_value_map; + // we need two ref cnt + // 1. clustered and covering: put row -> DataBlock(i) + // 2. secondary: put pkeys+pts -> DataBlock(j) + uint32_t real_ref_cnt = 0, secondary_ref_cnt = 0; + // cidx_inner_key_pair can get the clustered index + for (const auto& kv : inner_index_key_map) { + auto inner_index = table_index_.GetInnerIndex(kv.first); + if (!inner_index) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid inner index pos ", kv.first)); + } + std::map ts_map; + for (const auto& index_def : inner_index->GetIndex()) { + if (!index_def->IsReady()) { + continue; + } + auto ts_col = index_def->GetTsColumn(); + if (ts_col) { + int64_t ts = 0; + if (ts_col->IsAutoGenTs()) { + // clustered index still use current time to ttl and delete iter, we'll check time series size if ts + // is auto gen + ts = time; + } else if (decoder->GetInteger(data, ts_col->GetId(), ts_col->GetType(), &ts) != 0) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": get ts failed")); + } + if (ts < 0) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": ts is negative ", ts)); + } + // TODO(hw): why uint32_t to int32_t? 
+ ts_map.emplace(ts_col->GetId(), ts); + + if (index_def->IsSecondaryIndex()) { + secondary_ref_cnt++; + } else { + real_ref_cnt++; + } + if (index_def->IsClusteredIndex()) { + clustered_tsv = ts; + } + } + } + if (!ts_map.empty()) { + ts_value_map.emplace(kv.first, std::move(ts_map)); + } + } + if (ts_value_map.empty()) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": empty ts value map")); + } + // it's ok to have no clustered/covering put or no secondary put, put will be applyed on other pid + // but if no clustered/covering put and no secondary put, it's invalid, check it in put-loop + DataBlock* cblock = nullptr; + DataBlock* sblock = nullptr; + if (real_ref_cnt > 0) { + cblock = new DataBlock(real_ref_cnt, value.c_str(), value.length()); // hard copy + } + if (secondary_ref_cnt > 0) { + // dimensions may not contain cidx, but we need cidx pkeys+pts for secondary index + // if contains, just use the key; if not, extract from value + if (cidx_inner_key_pair.first == -1) { + DLOG(INFO) << "cidx not in dimensions, extract from value"; + auto cidx = table_index_.GetIndex(0); + auto hint = base::MakePkeysHint(table_meta_->column_desc(), table_meta_->column_key(0)); + if (hint.empty()) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": cidx pkeys hint empty")); + } + cidx_inner_key_pair.second = + base::ExtractPkeys(table_meta_->column_key(0), (int8_t*)value.c_str(), *decoder, hint); + if (cidx_inner_key_pair.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": cidx pkeys+pts extract failed")); + } + DLOG_ASSERT(!clustered_tsv) << "clustered ts should not be set too"; + auto ts_col = cidx->GetTsColumn(); + if (!ts_col) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ":no ts column in cidx")); + } + int64_t ts = 0; + if (ts_col->IsAutoGenTs()) { + // clustered index still use current time to ttl and delete iter, we'll check time series size if ts is + // auto gen + ts = time; 
+ } else if (decoder->GetInteger(data, ts_col->GetId(), ts_col->GetType(), &ts) != 0) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": get ts failed")); + } + if (ts < 0) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": ts is negative ", ts)); + } + clustered_tsv = ts; + } + auto pkeys_pts = PackPkeysAndPts(cidx_inner_key_pair.second, clustered_tsv.value()); + if (GetCompressType() == type::kSnappy) { // sidx iterator will uncompress when getting pkeys+pts + std::string val; + ::snappy::Compress(pkeys_pts.c_str(), pkeys_pts.length(), &val); + sblock = new DataBlock(secondary_ref_cnt, val.c_str(), val.length()); + } else { + sblock = new DataBlock(secondary_ref_cnt, pkeys_pts.c_str(), pkeys_pts.length()); // hard copy + } + } + DLOG(INFO) << "put iot table " << id_ << "." << pid_ << " key+ts " << cidx_inner_key_pair.second << " - " + << (clustered_tsv ? std::to_string(clustered_tsv.value()) : "-1") << ", real ref cnt " << real_ref_cnt + << " secondary ref cnt " << secondary_ref_cnt; + + for (const auto& kv : inner_index_key_map) { + auto iter = ts_value_map.find(kv.first); + if (iter == ts_value_map.end()) { + continue; + } + uint32_t seg_idx = SegIdx(kv.second.ToString()); + auto iot_segment = dynamic_cast(GetSegment(kv.first, seg_idx)); + // TODO(hw): put if absent unsupportted + if (put_if_absent) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": iot put if absent is not supported")); + } + // clustered segment should be dedup and update will trigger all index update(impl in cli router) + if (!iot_segment->Put(kv.second, iter->second, cblock, sblock, false)) { + // even no put_if_absent, return false if exists or wrong + return absl::AlreadyExistsError("data exists or wrong"); + } + } + // cblock and sblock both will sub record_byte_size_ when delete, so add them all + // TODO(hw): test for cal + if (real_ref_cnt > 0) { + record_byte_size_.fetch_add(GetRecordSize(cblock->size)); + } + if 
(secondary_ref_cnt > 0) { + record_byte_size_.fetch_add(GetRecordSize(sblock->size)); + } + + return absl::OkStatus(); +} + +absl::Status IndexOrganizedTable::CheckDataExists(uint64_t tsv, const Dimensions& dimensions) { + // get cidx dim + if (dimensions.empty()) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": empty dimension")); + } + // inner index pos: -1 means invalid, so it's positive in inner_index_key_map + std::pair cidx_inner_key_pair{-1, ""}; + for (auto iter = dimensions.begin(); iter != dimensions.end(); iter++) { + int32_t inner_pos = table_index_.GetInnerIndexPos(iter->idx()); + if (inner_pos < 0) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid dimension idx ", iter->idx())); + } + if (iter->idx() == 0) { + cidx_inner_key_pair = {inner_pos, iter->key()}; + } + } + if (cidx_inner_key_pair.first == -1) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": cidx not found")); + } + auto cidx = table_index_.GetIndex(0); + if (!cidx->IsReady()) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": cidx is not ready")); + } + auto ts_col = cidx->GetTsColumn(); + if (!ts_col) { + return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": no ts column")); + } + DLOG(INFO) << "check iot table " << id_ << "." 
<< pid_ << " key+ts " << cidx_inner_key_pair.second << " - " << tsv + << ", on index " << cidx->GetName() << " ts col " << ts_col->GetId(); + + uint32_t seg_idx = SegIdx(cidx_inner_key_pair.second); + auto iot_segment = dynamic_cast(GetSegment(cidx_inner_key_pair.first, seg_idx)); + // ts id -> ts value + return iot_segment->CheckKeyExists(cidx_inner_key_pair.second, {{ts_col->GetId(), tsv}}); +} + +// index gc should try to do ExecuteGc for each waiting segment, but if some segments were gc'ed before, we should +// release them, so it will be a little complex +// should run under lock +absl::Status IndexOrganizedTable::ClusteredIndexGCByDelete(const std::shared_ptr& router) { + auto cur_index = table_index_.GetIndex(0); + if (!cur_index) { + return absl::FailedPreconditionError( + absl::StrCat("cidx def is null for ", id_, ".", pid_)); // why index is null? + } + if (!cur_index->IsClusteredIndex()) { + return absl::InternalError(absl::StrCat("cidx is not clustered for ", id_, ".", pid_)); // impossible + } + if (!cur_index->IsReady()) { + return absl::FailedPreconditionError( + absl::StrCat("cidx is not ready for ", id_, ".", pid_, ", status ", cur_index->GetStatus())); + } + auto& ts_col = cur_index->GetTsColumn(); + // sometimes index def is valid, but ts_col is nullptr?
protect it + if (!ts_col) { + return absl::FailedPreconditionError( + absl::StrCat("no ts col of cidx for ", id_, ".", pid_)); // current time ts can be get too + } + // clustered index grep all entries or less to delete(it's simpler to run delete sql) + // not the real gc, so don't change index status + auto i = cur_index->GetId(); + std::map ttl_st_map; + // only set cidx + ttl_st_map.emplace(ts_col->GetId(), *cur_index->GetTTL()); + GCEntryInfo info; // not thread safe + for (uint32_t j = 0; j < seg_cnt_; j++) { + uint64_t seg_gc_time = ::baidu::common::timer::get_micros() / 1000; + Segment* segment = segments_[i][j]; + auto iot_segment = dynamic_cast(segment); + iot_segment->GrepGCEntry(ttl_st_map, &info); + seg_gc_time = ::baidu::common::timer::get_micros() / 1000 - seg_gc_time; + PDLOG(INFO, "grep cidx segment[%u][%u] gc entries done consumed %lu for table %s tid %u pid %u", i, j, + seg_gc_time, name_.c_str(), id_, pid_); + } + // delete entries by sql + if (info.Size() > 0) { + LOG(INFO) << "delete cidx " << info.Size() << " entries by sql"; + auto meta = GetTableMeta(); + auto cols = meta->column_desc(); // copy + codec::RowView row_view(cols); + auto hint = base::MakePkeysHint(cols, meta->column_key(0)); + if (hint.empty()) { + return absl::InternalError("make pkeys hint failed"); + } + for (size_t i = 0; i < info.Size(); i++) { + auto& keys_ts = info.GetEntries()[i]; + auto values = keys_ts.second; // get pkeys from values + auto ts = keys_ts.first; + auto sql = + base::MakeDeleteSQL(GetDB(), GetName(), meta->column_key(0), (int8_t*)values->data, ts, row_view, hint); + // TODO(hw): if delete failed, we can't revert. 
And if sidx skeys+sts doesn't change, no need to delete and + // then insert + if (sql.empty()) { + return absl::InternalError("make delete sql failed"); + } + // delete will move node to node cache, it's alive, so GCEntryInfo can unref it + hybridse::sdk::Status status; + router->ExecuteSQL(sql, &status); + if (!status.IsOK()) { + return absl::InternalError("execute sql failed " + status.ToString()); + } + } + } + + return absl::OkStatus(); +} + +// TODO(hw): don't refactor with MemTable, make MemTable stable +void IndexOrganizedTable::SchedGCByDelete(const std::shared_ptr& router) { + std::lock_guard lock(gc_lock_); + uint64_t consumed = ::baidu::common::timer::get_micros(); + if (!enable_gc_.load(std::memory_order_relaxed)) { + LOG(INFO) << "iot table " << name_ << "[" << id_ << "." << pid_ << "] gc disabled"; + return; + } + LOG(INFO) << "iot table " << name_ << "[" << id_ << "." << pid_ << "] start making gc"; + // gc cidx first, it'll delete on all indexes + auto st = ClusteredIndexGCByDelete(router); + if (!st.ok()) { + LOG(WARNING) << "cidx gc by delete error: " << st.ToString(); + } + // TODO how to check the record byte size? 
+ uint64_t gc_idx_cnt = 0; + uint64_t gc_record_byte_size = 0; + auto inner_indexs = table_index_.GetAllInnerIndex(); + for (uint32_t i = 0; i < inner_indexs->size(); i++) { + const std::vector>& real_index = inner_indexs->at(i)->GetIndex(); + std::map ttl_st_map; + bool need_gc = true; + size_t deleted_num = 0; + std::vector deleting_pos; + for (size_t pos = 0; pos < real_index.size(); pos++) { + auto cur_index = real_index[pos]; + auto ts_col = cur_index->GetTsColumn(); + if (ts_col) { + ttl_st_map.emplace(ts_col->GetId(), *(cur_index->GetTTL())); + } + if (cur_index->GetStatus() == IndexStatus::kWaiting) { + cur_index->SetStatus(IndexStatus::kDeleting); + need_gc = false; + } else if (cur_index->GetStatus() == IndexStatus::kDeleting) { + deleting_pos.push_back(pos); + } else if (cur_index->GetStatus() == IndexStatus::kDeleted) { + deleted_num++; + } + } + if (!deleting_pos.empty()) { + if (segments_[i] != nullptr) { + for (uint32_t k = 0; k < seg_cnt_; k++) { + if (segments_[i][k] != nullptr) { + StatisticsInfo statistics_info(segments_[i][k]->GetTsCnt()); + if (real_index.size() == 1 || deleting_pos.size() + deleted_num == real_index.size()) { + segments_[i][k]->ReleaseAndCount(&statistics_info); + } else { + segments_[i][k]->ReleaseAndCount(deleting_pos, &statistics_info); + } + gc_idx_cnt += statistics_info.GetTotalCnt(); + gc_record_byte_size += statistics_info.record_byte_size; + LOG(INFO) << "release segment[" << i << "][" << k << "] done, gc record cnt " + << statistics_info.GetTotalCnt() << ", gc record byte size " + << statistics_info.record_byte_size; + } + } + } + for (auto pos : deleting_pos) { + real_index[pos]->SetStatus(IndexStatus::kDeleted); + } + deleted_num += deleting_pos.size(); + } + if (!need_gc) { + continue; + } + // skip cidx gc in segment, gcfreelist shouldn't be skiped, so we don't change the condition + if (deleted_num == real_index.size() || ttl_st_map.empty()) { + continue; + } + for (uint32_t j = 0; j < seg_cnt_; j++) { + uint64_t 
seg_gc_time = ::baidu::common::timer::get_micros() / 1000; + auto segment = dynamic_cast(segments_[i][j]); + StatisticsInfo statistics_info(segment->GetTsCnt()); + segment->IncrGcVersion(); + segment->GcFreeList(&statistics_info); + // don't gc in cidx, it's not a good way to impl, refactor later + segment->ExecuteGc(ttl_st_map, &statistics_info, segment->ClusteredTs()); + gc_idx_cnt += statistics_info.GetTotalCnt(); + gc_record_byte_size += statistics_info.record_byte_size; + seg_gc_time = ::baidu::common::timer::get_micros() / 1000 - seg_gc_time; + VLOG(1) << "gc segment[" << i << "][" << j << "] done, consumed time " << seg_gc_time << "ms for table " + << name_ << "[" << id_ << "." << pid_ << "], statistics_info: [" << statistics_info.DebugString() + << "]"; + } + } + consumed = ::baidu::common::timer::get_micros() - consumed; + LOG(INFO) << "record byte size before gc: " << record_byte_size_.load() + << ", gc record byte size: " << gc_record_byte_size << ", gc idx cnt: " << gc_idx_cnt + << ", gc consumed: " << consumed / 1000 << " ms"; + record_byte_size_.fetch_sub(gc_record_byte_size, std::memory_order_relaxed); + UpdateTTL(); + LOG(INFO) << "update ttl done"; +} + +bool IndexOrganizedTable::AddIndexToTable(const std::shared_ptr& index_def) { + std::vector ts_vec = {index_def->GetTsColumn()->GetId()}; + uint32_t inner_id = index_def->GetInnerPos(); + Segment** seg_arr = new Segment*[seg_cnt_]; + for (uint32_t j = 0; j < seg_cnt_; j++) { + seg_arr[j] = new IOTSegment(FLAGS_absolute_default_skiplist_height, ts_vec, {index_def->GetIndexType()}); + LOG(INFO) << "init iot segment inner_ts" << inner_id << "." << j << " for table " << name_ << "[" << id_ << "." 
+ << pid_ << "], height " << FLAGS_absolute_default_skiplist_height << ", ts col num " << ts_vec.size() + << ", type " << IndexType_Name(index_def->GetIndexType()); + } + segments_[inner_id] = seg_arr; + return true; +} + +} // namespace openmldb::storage diff --git a/src/storage/index_organized_table.h b/src/storage/index_organized_table.h new file mode 100644 index 00000000000..014e1a56a0a --- /dev/null +++ b/src/storage/index_organized_table.h @@ -0,0 +1,68 @@ +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef SRC_STORAGE_INDEX_ORGANIZED_TABLE_H_ +#define SRC_STORAGE_INDEX_ORGANIZED_TABLE_H_ + +#include + +#include "catalog/tablet_catalog.h" +#include "storage/mem_table.h" + +namespace openmldb::storage { + +class IndexOrganizedTable : public MemTable { + public: + IndexOrganizedTable(const ::openmldb::api::TableMeta& table_meta, std::shared_ptr catalog) + : MemTable(table_meta), catalog_(catalog) {} + + TableIterator* NewIterator(uint32_t index, const std::string& pk, Ticket& ticket) override; + + TraverseIterator* NewTraverseIterator(uint32_t index) override; + + ::hybridse::vm::WindowIterator* NewWindowIterator(uint32_t index) override; + + bool Init() override; + + bool Put(const std::string& pk, uint64_t time, const char* data, uint32_t size) override; + + absl::Status Put(uint64_t time, const std::string& value, const Dimensions& dimensions, + bool put_if_absent) override; + + absl::Status CheckDataExists(uint64_t tsv, const Dimensions& dimensions); + + // TODO(hw): iot bulk load unsupported + bool GetBulkLoadInfo(::openmldb::api::BulkLoadInfoResponse* response) { return false; } + bool BulkLoad(const std::vector& data_blocks, + const ::google::protobuf::RepeatedPtrField<::openmldb::api::BulkLoadIndex>& indexes) { + return false; + } + bool AddIndexToTable(const std::shared_ptr& index_def) override; + + void SchedGCByDelete(const std::shared_ptr& router); + + private: + absl::Status ClusteredIndexGCByDelete(const std::shared_ptr& router); + + private: + // to get current distribute iterator + std::shared_ptr catalog_; + + std::mutex gc_lock_; +}; + +} // namespace openmldb::storage +#endif diff --git a/src/storage/iot_segment.cc b/src/storage/iot_segment.cc new file mode 100644 index 00000000000..89a19e4838f --- /dev/null +++ b/src/storage/iot_segment.cc @@ -0,0 +1,412 @@ +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "storage/iot_segment.h" + +#include "iot_segment.h" + +namespace openmldb::storage { +base::Slice RowToSlice(const ::hybridse::codec::Row& row) { + butil::IOBuf buf; + size_t size; + if (codec::EncodeRpcRow(row, &buf, &size)) { + auto r = new char[buf.size()]; + buf.copy_to(r); // TODO(hw): don't copy, move it to slice + // slice own the new r + return {r, size, true}; + } + LOG(WARNING) << "convert row to slice failed"; + return {}; +} + +std::string PackPkeysAndPts(const std::string& pkeys, uint64_t pts) { + std::string buf; + uint32_t pkeys_size = pkeys.size(); + buf.append(reinterpret_cast(&pkeys_size), sizeof(uint32_t)); + buf.append(pkeys); + buf.append(reinterpret_cast(&pts), sizeof(uint64_t)); + return buf; +} + +bool UnpackPkeysAndPts(const std::string& block, std::string* pkeys, uint64_t* pts) { + DLOG_ASSERT(block.size() >= sizeof(uint32_t) + sizeof(uint64_t)) << "block size is " << block.size(); + uint32_t offset = 0; + uint32_t pkeys_size = *reinterpret_cast(block.data() + offset); + offset += sizeof(uint32_t); + pkeys->assign(block.data() + offset, pkeys_size); + offset += pkeys_size; + *pts = *reinterpret_cast(block.data() + offset); + DLOG_ASSERT(offset + sizeof(uint64_t) == block.size()) + << "offset is " << offset << " block size is " << block.size(); + return true; +} + +// put_if_absent unsupported, iot table will reject put, no need to check here, just ignore +bool IOTSegment::PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent, bool auto_gen_ts) { + void* entry = nullptr; + uint32_t 
byte_size = 0; + // one key just one entry + int ret = entries_->Get(key, entry); + if (ret < 0 || entry == nullptr) { + char* pk = new char[key.size()]; + memcpy(pk, key.data(), key.size()); + // need to delete memory when free node + Slice skey(pk, key.size()); + entry = reinterpret_cast(new KeyEntry(key_entry_max_height_)); + uint8_t height = entries_->Insert(skey, entry); + byte_size += GetRecordPkIdxSize(height, key.size(), key_entry_max_height_); + pk_cnt_.fetch_add(1, std::memory_order_relaxed); + // no need to check if absent when first put + } else if (IsClusteredTs(ts_idx_map_.begin()->first)) { + // if cidx and key match, check ts -> insert or update + if (auto_gen_ts) { + // cidx(keys) has just one entry for one keys, so if keys exists, needs delete + LOG_IF(ERROR, reinterpret_cast(entry)->entries.GetSize() > 1) + << "cidx keys has more than one entry, " << reinterpret_cast(entry)->entries.GetSize(); + // TODO(hw): client will delete old row, so if pkeys exists when auto ts, fail it + return false; + } else { + // cidx(keys+ts) check if ts match + if (ListContains(reinterpret_cast(entry), time, row, false)) { + LOG(WARNING) << "key " << key.ToString() << " ts " << time << " exists in cidx"; + return false; + } + } + } + + idx_cnt_vec_[0]->fetch_add(1, std::memory_order_relaxed); + uint8_t height = reinterpret_cast(entry)->entries.Insert(time, row); + reinterpret_cast(entry)->count_.fetch_add(1, std::memory_order_relaxed); + byte_size += GetRecordTsIdxSize(height); + idx_byte_size_.fetch_add(byte_size, std::memory_order_relaxed); + DLOG(INFO) << "idx_byte_size_ " << idx_byte_size_ << " after add " << byte_size; + return true; +} + +bool IOTSegment::Put(const Slice& key, const std::map& ts_map, DataBlock* cblock, DataBlock* sblock, + bool put_if_absent) { + if (ts_map.empty()) { + return false; + } + if (ts_cnt_ == 1) { + bool ret = false; + if (auto pos = ts_map.find(ts_idx_map_.begin()->first); pos != ts_map.end()) { + // TODO(hw): why ts_map key is 
int32_t, default ts is uint32_t? + ret = Segment::Put(key, pos->second, + (index_types_[ts_idx_map_.begin()->second] == common::kSecondary ? sblock : cblock), + false, pos->first == DEFAULT_TS_COL_ID); + } + return ret; + } + void* entry_arr = nullptr; + std::lock_guard lock(mu_); + for (const auto& kv : ts_map) { + uint32_t byte_size = 0; + auto pos = ts_idx_map_.find(kv.first); + if (pos == ts_idx_map_.end()) { + continue; + } + if (entry_arr == nullptr) { + int ret = entries_->Get(key, entry_arr); + if (ret < 0 || entry_arr == nullptr) { + char* pk = new char[key.size()]; + memcpy(pk, key.data(), key.size()); + Slice skey(pk, key.size()); + KeyEntry** entry_arr_tmp = new KeyEntry*[ts_cnt_]; + for (uint32_t i = 0; i < ts_cnt_; i++) { + entry_arr_tmp[i] = new KeyEntry(key_entry_max_height_); + } + entry_arr = reinterpret_cast(entry_arr_tmp); + uint8_t height = entries_->Insert(skey, entry_arr); + byte_size += GetRecordPkMultiIdxSize(height, key.size(), key_entry_max_height_, ts_cnt_); + pk_cnt_.fetch_add(1, std::memory_order_relaxed); + } + } + auto entry = reinterpret_cast(entry_arr)[pos->second]; + auto auto_gen_ts = (pos->first == DEFAULT_TS_COL_ID); + auto pblock = (index_types_[pos->second] == common::kSecondary ? 
sblock : cblock); + if (IsClusteredTs(pos->first)) { + // if cidx and key match, check ts -> insert or update + if (auto_gen_ts) { + // cidx(keys) has just one entry for one keys, so if keys exists, needs delete + LOG_IF(ERROR, reinterpret_cast(entry)->entries.GetSize() > 1) + << "cidx keys has more than one entry, " << reinterpret_cast(entry)->entries.GetSize(); + // TODO(hw): client will delete old row, so if pkeys exists when auto ts, fail it + if (reinterpret_cast(entry)->entries.GetSize() > 0) { + LOG(WARNING) << "key " << key.ToString() << " exists in cidx"; + return false; + } + } else { + // cidx(keys+ts) check if ts match + if (ListContains(reinterpret_cast(entry), kv.second, pblock, false)) { + LOG(WARNING) << "key " << key.ToString() << " ts " << kv.second << " exists in cidx"; + return false; + } + } + } + uint8_t height = entry->entries.Insert(kv.second, pblock); + entry->count_.fetch_add(1, std::memory_order_relaxed); + byte_size += GetRecordTsIdxSize(height); + idx_byte_size_.fetch_add(byte_size, std::memory_order_relaxed); + DLOG(INFO) << "idx_byte_size_ " << idx_byte_size_ << " after add " << byte_size; + idx_cnt_vec_[pos->second]->fetch_add(1, std::memory_order_relaxed); + } + return true; +} + +absl::Status IOTSegment::CheckKeyExists(const Slice& key, const std::map& ts_map) { + // check lock + void* entry_arr = nullptr; + std::lock_guard lock(mu_); // need shrink? 
+ int ret = entries_->Get(key, entry_arr); + if (ret < 0 || entry_arr == nullptr) { + return absl::NotFoundError("key not found"); + } + if (ts_map.size() != 1) { + return absl::InvalidArgumentError("ts map size is not 1"); + } + auto idx_ts = ts_map.begin(); + auto pos = ts_idx_map_.find(idx_ts->first); + if (pos == ts_idx_map_.end()) { + return absl::InvalidArgumentError("ts not found"); + } + // be careful, ts id in arg maybe negative cuz it's int32, but id in member is uint32 + if (!IsClusteredTs(idx_ts->first)) { + LOG(WARNING) << "idx_ts->first " << idx_ts->first << " is not clustered ts " + << (clustered_ts_id_.has_value() ? std::to_string(clustered_ts_id_.value()) : "no"); + return absl::InvalidArgumentError("ts is not clustered"); + } + KeyEntry* entry = nullptr; + if (ts_cnt_ == 1) { + LOG_IF(ERROR, pos->second != 0) << "when ts cnt == 1, pos second is " << pos->second; + entry = reinterpret_cast(entry_arr); + } else { + entry = reinterpret_cast(entry_arr)[pos->second]; + } + + if (entry == nullptr) { + return absl::NotFoundError("ts entry not found"); + } + auto auto_gen_ts = (idx_ts->first == DEFAULT_TS_COL_ID); + if (auto_gen_ts) { + // cidx(keys) has just one entry for one keys, so if keys exists, needs delete + DLOG_ASSERT(reinterpret_cast(entry)->entries.GetSize() == 1) << "cidx keys has more than one entry"; + if (reinterpret_cast(entry)->entries.GetSize() > 0) { + return absl::AlreadyExistsError("key exists: " + key.ToString()); + } + } else { + // don't use listcontains, we don't need to check value, just check if time exists + storage::DataBlock* v = nullptr; + if (entry->entries.Get(idx_ts->second, v) == 0) { + return absl::AlreadyExistsError(absl::StrCat("key+ts exists: ", key.ToString(), ", ts ", idx_ts->second)); + } + } + + return absl::NotFoundError("ts not found"); +} +// TODO(hw): when add lock? 
ref segment, don't lock iter +void IOTSegment::GrepGCEntry(const std::map& ttl_st_map, GCEntryInfo* gc_entry_info) { + if (ttl_st_map.empty()) { + DLOG(INFO) << "ttl map is empty, skip gc"; + return; + } + + bool need_gc = false; + for (const auto& kv : ttl_st_map) { + if (ts_idx_map_.find(kv.first) == ts_idx_map_.end()) { + LOG(WARNING) << "ts idx " << kv.first << " not found"; + return; + } + if (kv.second.NeedGc()) { + need_gc = true; + } + } + if (!need_gc) { + DLOG(INFO) << "no need gc, skip gc"; + return; + } + GrepGCAllType(ttl_st_map, gc_entry_info); +} + +void GrepGC4Abs(KeyEntry* entry, const Slice& key, const TTLSt& ttl, uint64_t cur_time, uint64_t ttl_offset, + GCEntryInfo* gc_entry_info) { + if (ttl.abs_ttl == 0) { + return; // never expire + } + uint64_t expire_time = cur_time - ttl_offset - ttl.abs_ttl; + std::unique_ptr iter(entry->entries.NewIterator()); + iter->Seek(expire_time); + // delete (expire, last] + while (iter->Valid()) { + if (iter->GetKey() > expire_time) { + break; + } + // expire_time has offset, so we don't need to check if equal + // if (iter->GetKey() == expire_time) { + // continue; // save ==, don't gc + // } + gc_entry_info->AddEntry(key, iter->GetKey(), iter->GetValue()); + if (gc_entry_info->Full()) { + LOG(INFO) << "gc entry info full, stop gc grep"; + return; + } + iter->Next(); + } +} + +void GrepGC4Lat(KeyEntry* entry, const Slice& key, const TTLSt& ttl, GCEntryInfo* gc_entry_info) { + auto keep_cnt = ttl.lat_ttl; + if (keep_cnt == 0) { + return; // never expire + } + + std::unique_ptr iter(entry->entries.NewIterator()); + iter->SeekToFirst(); + while (iter->Valid()) { + if (keep_cnt > 0) { + keep_cnt--; + } else { + gc_entry_info->AddEntry(key, iter->GetKey(), iter->GetValue()); + } + if (gc_entry_info->Full()) { + LOG(INFO) << "gc entry info full, stop gc grep"; + return; + } + iter->Next(); + } +} + +void GrepGC4AbsAndLat(KeyEntry* entry, const Slice& key, const TTLSt& ttl, uint64_t cur_time, uint64_t ttl_offset, +
GCEntryInfo* gc_entry_info) { + if (ttl.abs_ttl == 0 || ttl.lat_ttl == 0) { + return; // never expire + } + // keep both + uint64_t expire_time = cur_time - ttl_offset - ttl.abs_ttl; + auto keep_cnt = ttl.lat_ttl; + std::unique_ptr iter(entry->entries.NewIterator()); + iter->SeekToFirst(); + // if > lat cnt and < expire, delete + while (iter->Valid()) { + if (keep_cnt > 0) { + keep_cnt--; + } else if (iter->GetKey() < expire_time) { + gc_entry_info->AddEntry(key, iter->GetKey(), iter->GetValue()); + } + if (gc_entry_info->Full()) { + LOG(INFO) << "gc entry info full, stop gc grep"; + return; + } + iter->Next(); + } +} +void GrepGC4AbsOrLat(KeyEntry* entry, const Slice& key, const TTLSt& ttl, uint64_t cur_time, uint64_t ttl_offset, + GCEntryInfo* gc_entry_info) { + if (ttl.abs_ttl == 0 && ttl.lat_ttl == 0) { + return; + } + if (ttl.abs_ttl == 0) { + // == lat ttl + GrepGC4Lat(entry, key, ttl, gc_entry_info); + return; + } + if (ttl.lat_ttl == 0) { + GrepGC4Abs(entry, key, ttl, cur_time, ttl_offset, gc_entry_info); + return; + } + uint64_t expire_time = cur_time - ttl_offset - ttl.abs_ttl; + auto keep_cnt = ttl.lat_ttl; + std::unique_ptr iter(entry->entries.NewIterator()); + iter->SeekToFirst(); + // if > keep cnt or < expire time, delete + while (iter->Valid()) { + if (keep_cnt > 0) { + keep_cnt--; // safe + } else { + gc_entry_info->AddEntry(key, iter->GetKey(), iter->GetValue()); + iter->Next(); + continue; + } + if (iter->GetKey() < expire_time) { + gc_entry_info->AddEntry(key, iter->GetKey(), iter->GetValue()); + } + if (gc_entry_info->Full()) { + LOG(INFO) << "gc entry info full, stop gc grep"; + return; + } + iter->Next(); + } +} + +// actually only one ttl for cidx, clean up later +void IOTSegment::GrepGCAllType(const std::map& ttl_st_map, GCEntryInfo* gc_entry_info) { + uint64_t consumed = ::baidu::common::timer::get_micros(); + uint64_t cur_time = consumed / 1000; + std::unique_ptr it(entries_->NewIterator()); + it->SeekToFirst(); + while (it->Valid()) {
KeyEntry** entry_arr = reinterpret_cast(it->GetValue()); + Slice key = it->GetKey(); + it->Next(); + for (const auto& kv : ttl_st_map) { + DLOG(INFO) << "key " << key.ToString() << ", ts idx " << kv.first << ", ttl " << kv.second.ToString() + << ", ts_cnt_ " << ts_cnt_; + if (!kv.second.NeedGc()) { + continue; + } + auto pos = ts_idx_map_.find(kv.first); + if (pos == ts_idx_map_.end() || pos->second >= ts_cnt_) { + LOG(WARNING) << "gc ts idx " << kv.first << " not found"; + continue; + } + KeyEntry* entry = nullptr; + // time series :[(ts, row), ...], so get key means get ts + if (ts_cnt_ == 1) { + LOG_IF(DFATAL, pos->second != 0) << "when ts cnt == 1, pos second is " << pos->second; + entry = reinterpret_cast(entry_arr); + } else { + entry = entry_arr[pos->second]; + } + if (entry == nullptr) { + DLOG(DFATAL) << "entry is null, impossible"; + continue; + } + switch (kv.second.ttl_type) { + case ::openmldb::storage::TTLType::kAbsoluteTime: { + GrepGC4Abs(entry, key, kv.second, cur_time, ttl_offset_, gc_entry_info); + break; + } + case ::openmldb::storage::TTLType::kLatestTime: { + GrepGC4Lat(entry, key, kv.second, gc_entry_info); + break; + } + case ::openmldb::storage::TTLType::kAbsAndLat: { + GrepGC4AbsAndLat(entry, key, kv.second, cur_time, ttl_offset_, gc_entry_info); + break; + } + case ::openmldb::storage::TTLType::kAbsOrLat: { + GrepGC4AbsOrLat(entry, key, kv.second, cur_time, ttl_offset_, gc_entry_info); + break; + } + default: + return; + } + } + } + DLOG(INFO) << "[GC ts map] iot segment gc consumed " << (::baidu::common::timer::get_micros() - consumed) / 1000 + << "ms, gc entry size " << gc_entry_info->Size(); +} +} // namespace openmldb::storage diff --git a/src/storage/iot_segment.h b/src/storage/iot_segment.h new file mode 100644 index 00000000000..b610241f240 --- /dev/null +++ b/src/storage/iot_segment.h @@ -0,0 +1,298 @@ +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_STORAGE_IOT_SEGMENT_H_ +#define SRC_STORAGE_IOT_SEGMENT_H_ + +#include "catalog/tablet_catalog.h" +#include "codec/row_codec.h" +#include "codec/row_iterator.h" +#include "codec/sql_rpc_row_codec.h" +#include "storage/mem_table_iterator.h" +#include "storage/segment.h" +#include "storage/table.h" // for storage::Schema + +DECLARE_uint32(cidx_gc_max_size); + +namespace openmldb::storage { + +base::Slice RowToSlice(const ::hybridse::codec::Row& row); + +// [pkeys_size, pkeys, pts_size, ts_id, tsv, ...] +std::string PackPkeysAndPts(const std::string& pkeys, uint64_t pts); +bool UnpackPkeysAndPts(const std::string& block, std::string* pkeys, uint64_t* pts); + +// secondary index iterator +// GetValue will lookup, and it may trigger rpc +class IOTIterator : public MemTableIterator { + public: + IOTIterator(TimeEntries::Iterator* it, type::CompressType compress_type, + std::unique_ptr<::hybridse::codec::WindowIterator> cidx_iter) + : MemTableIterator(it, compress_type), cidx_iter_(std::move(cidx_iter)) {} + virtual ~IOTIterator() {} + + openmldb::base::Slice GetValue() const override { + auto pkeys_pts = MemTableIterator::GetValue(); + std::string pkeys; + uint64_t ts; + if (!UnpackPkeysAndPts(pkeys_pts.ToString(), &pkeys, &ts)) { + LOG(WARNING) << "unpack pkeys and pts failed"; + return ""; + } + cidx_iter_->Seek(pkeys); + if (cidx_iter_->Valid()) { + // seek to ts + auto ts_iter = cidx_iter_->GetValue(); + ts_iter->Seek(ts); + if (ts_iter->Valid()) { + return 
RowToSlice(ts_iter->GetValue()); + } + } + // TODO(hw): Valid() to check row data? what if only one entry invalid? + return ""; + } + + private: + std::unique_ptr<::hybridse::codec::WindowIterator> cidx_iter_; +}; + +class IOTTraverseIterator : public MemTableTraverseIterator { + public: + IOTTraverseIterator(Segment** segments, uint32_t seg_cnt, ::openmldb::storage::TTLType ttl_type, + uint64_t expire_time, uint64_t expire_cnt, uint32_t ts_index, type::CompressType compress_type, + std::unique_ptr<::hybridse::codec::WindowIterator> cidx_iter) + : MemTableTraverseIterator(segments, seg_cnt, ttl_type, expire_time, expire_cnt, ts_index, compress_type), + cidx_iter_(std::move(cidx_iter)) {} + ~IOTTraverseIterator() override {} + + openmldb::base::Slice GetValue() const override { + auto pkeys_pts = MemTableTraverseIterator::GetValue(); + std::string pkeys; + uint64_t ts; + if (!UnpackPkeysAndPts(pkeys_pts.ToString(), &pkeys, &ts)) { + LOG(WARNING) << "unpack pkeys and pts failed"; + return ""; + } + // distribute cidx iter should seek to (key, ts) + DLOG(INFO) << "seek to " << pkeys << ", " << ts; + cidx_iter_->Seek(pkeys); + if (cidx_iter_->Valid()) { + // seek to ts + auto ts_iter_ = cidx_iter_->GetValue(); + ts_iter_->Seek(ts); + if (ts_iter_->Valid()) { + // TODO(hw): hard copy, or hold ts_iter to store value? IOTIterator should be the same. + DLOG(INFO) << "valid, " << ts_iter_->GetValue().ToString(); + return RowToSlice(ts_iter_->GetValue()); + } + } + LOG(WARNING) << "no suitable iter"; + return ""; // won't core, just no row for select? 
+ } + + private: + std::unique_ptr<::hybridse::codec::WindowIterator> cidx_iter_; + std::unique_ptr ts_iter_; +}; + +class IOTWindowIterator : public MemTableWindowIterator { + public: + IOTWindowIterator(TimeEntries::Iterator* it, ::openmldb::storage::TTLType ttl_type, uint64_t expire_time, + uint64_t expire_cnt, type::CompressType compress_type, + std::unique_ptr<::hybridse::codec::WindowIterator> cidx_iter) + : MemTableWindowIterator(it, ttl_type, expire_time, expire_cnt, compress_type), + cidx_iter_(std::move(cidx_iter)) { + DLOG(INFO) << "create IOTWindowIterator"; + } + // for debug + void SetSchema(const codec::Schema& schema, const std::vector& pkeys_idx) { + pkeys_idx_ = pkeys_idx; + row_view_.reset(new codec::RowView(schema)); + } + const ::hybridse::codec::Row& GetValue() override { + auto pkeys_pts = MemTableWindowIterator::GetValue(); + if (pkeys_pts.empty()) { + LOG(WARNING) << "empty pkeys_pts for key " << GetKey(); + return dummy; + } + + // unpack the row and get pkeys+pts + // Row -> cols + std::string pkeys; + uint64_t ts; + if (!UnpackPkeysAndPts(pkeys_pts.ToString(), &pkeys, &ts)) { + LOG(WARNING) << "unpack pkeys and pts failed"; + return dummy; + } + // TODO(hw): what if no ts? it'll be 0 for temp + DLOG(INFO) << "pkeys=" << pkeys << ", ts=" << ts; + cidx_iter_->Seek(pkeys); + if (cidx_iter_->Valid()) { + // seek to ts + DLOG(INFO) << "seek to ts " << ts; + // hold the row iterator to avoid invalidation + cidx_ts_iter_ = std::move(cidx_iter_->GetValue()); + cidx_ts_iter_->Seek(ts); + // must be the same keys+ts + if (cidx_ts_iter_->Valid()) { + // DLOG(INFO) << "valid, is the same value? " << GetKeys(cidx_ts_iter_->GetValue()); + return cidx_ts_iter_->GetValue(); + } + } + // Valid() to check row data? what if only one entry invalid? 
+ return dummy; + } + + private: + std::string GetKeys(const hybridse::codec::Row& pkeys_pts) { + std::string pkeys, key; // RowView Get will assign output, no need to clear + for (auto pkey_idx : pkeys_idx_) { + if (!pkeys.empty()) { + pkeys += "|"; + } + // TODO(hw): if null, append to key? + auto ret = row_view_->GetStrValue(pkeys_pts.buf(), pkey_idx, &key); + if (ret == -1) { + LOG(WARNING) << "get pkey failed"; + return {}; + } + pkeys += key.empty() ? hybridse::codec::EMPTY_STRING : key; + DLOG(INFO) << pkey_idx << "=" << key; + } + return pkeys; + } + + private: + std::unique_ptr<::hybridse::codec::WindowIterator> cidx_iter_; + std::unique_ptr cidx_ts_iter_; + // for debug + std::unique_ptr row_view_; + std::vector pkeys_idx_; + + ::hybridse::codec::Row dummy; +}; + +class IOTKeyIterator : public MemTableKeyIterator { + public: + IOTKeyIterator(Segment** segments, uint32_t seg_cnt, ::openmldb::storage::TTLType ttl_type, uint64_t expire_time, + uint64_t expire_cnt, uint32_t ts_index, type::CompressType compress_type, + std::shared_ptr cidx_handler, const std::string& cidx_name) + : MemTableKeyIterator(segments, seg_cnt, ttl_type, expire_time, expire_cnt, ts_index, compress_type) { + // cidx_iter will be used by RowIterator but it's unique, so create it when get RowIterator + cidx_handler_ = cidx_handler; + cidx_name_ = cidx_name; + } + + ~IOTKeyIterator() override {} + void SetSchema(const std::shared_ptr& schema, + const std::shared_ptr& cidx) { + schema_ = *schema; // copy + // pkeys idx + std::map col_idx_map; + for (int i = 0; i < schema_.size(); i++) { + col_idx_map[schema_[i].name()] = i; + } + pkeys_idx_.clear(); + for (auto pkey : cidx->GetColumns()) { + pkeys_idx_.emplace_back(col_idx_map[pkey.GetName()]); + } + } + ::hybridse::vm::RowIterator* GetRawValue() override { + DLOG(INFO) << "GetRawValue for key " << GetKey().ToString() << ", bind cidx " << cidx_name_; + TimeEntries::Iterator* it = GetTimeIter(); + auto cidx_iter = 
cidx_handler_->GetWindowIterator(cidx_name_); + auto iter = + new IOTWindowIterator(it, ttl_type_, expire_time_, expire_cnt_, compress_type_, std::move(cidx_iter)); + // iter->SetSchema(schema_, pkeys_idx_); + return iter; + } + + private: + std::shared_ptr cidx_handler_; + std::string cidx_name_; + // test + codec::Schema schema_; + std::vector pkeys_idx_; +}; + +class GCEntryInfo { + public: + typedef std::pair Entry; + ~GCEntryInfo() { + for (auto& entry : entries_) { + entry.second->dim_cnt_down--; + // data block should be moved to node_cache then delete + // I don't want delete block here + LOG_IF(ERROR, entry.second->dim_cnt_down == 0) << "dim_cnt_down=0 but no delete"; + } + } + void AddEntry(const Slice& keys, uint64_t ts, storage::DataBlock* ptr) { + // to avoid Block deleted before gc, add ref + ptr->dim_cnt_down++; // TODO(hw): no concurrency? or make sure under lock + entries_.emplace_back(ts, ptr); + } + std::size_t Size() { return entries_.size(); } + std::vector& GetEntries() { return entries_; } + bool Full() { return entries_.size() >= FLAGS_cidx_gc_max_size; } + + private: + // std::vector> entries_; + std::vector entries_; +}; + +class IOTSegment : public Segment { + public: + explicit IOTSegment(uint8_t height) : Segment(height) {} + IOTSegment(uint8_t height, const std::vector& ts_idx_vec, + const std::vector& index_types) + : Segment(height, ts_idx_vec), index_types_(index_types) { + // find clustered ts id + for (uint32_t i = 0; i < ts_idx_vec.size(); i++) { + if (index_types_[i] == common::kClustered) { + clustered_ts_id_ = ts_idx_vec[i]; + break; + } + } + } + ~IOTSegment() override {} + + bool PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent, bool check_all_time); + bool Put(const Slice& key, const std::map& ts_map, DataBlock* cblock, DataBlock* sblock, + bool put_if_absent = false); + // use ts map to get idx in entry_arr + // no ok status, exists or not found + absl::Status CheckKeyExists(const Slice& key, 
const std::map& ts_map); + // DEFAULT_TS_COL_ID is uint32_t max, so clsutered_ts_id_ can't have a init value, use std::optional + bool IsClusteredTs(uint32_t ts_id) { + return clustered_ts_id_.has_value() ? (ts_id == clustered_ts_id_.value()) : false; + } + + std::optional ClusteredTs() const { return clustered_ts_id_; } + + void GrepGCEntry(const std::map& ttl_st_map, GCEntryInfo* gc_entry_info); + + // if segment is not secondary idx, use normal NewIterator in Segment + + private: + void GrepGCAllType(const std::map& ttl_st_map, GCEntryInfo* gc_entry_info); + + private: + std::vector index_types_; + std::optional clustered_ts_id_; +}; + +} // namespace openmldb::storage +#endif // SRC_STORAGE_IOT_SEGMENT_H_ diff --git a/src/storage/iot_segment_test.cc b/src/storage/iot_segment_test.cc new file mode 100644 index 00000000000..312c92c5a87 --- /dev/null +++ b/src/storage/iot_segment_test.cc @@ -0,0 +1,517 @@ +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "storage/iot_segment.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "base/glog_wrapper.h" +#include "base/slice.h" +#include "gtest/gtest.h" +#include "storage/record.h" + +using ::openmldb::base::Slice; + +namespace openmldb { +namespace storage { + +// iotsegment is not the same with segment, so we need to test it separately +class IOTSegmentTest : public ::testing::Test { + public: + IOTSegmentTest() {} + ~IOTSegmentTest() {} +}; + +TEST_F(IOTSegmentTest, PutAndScan) { + IOTSegment segment(8, {1, 3, 5}, + {common::IndexType::kClustered, common::IndexType::kSecondary, common::IndexType::kCovering}); + Slice pk("test1"); + std::string value = "test0"; + auto cblk = new DataBlock(2, value.c_str(), value.size()); // 1 clustered + 1 covering, hard copy + auto sblk = new DataBlock(1, value.c_str(), value.size()); // 1 secondary, fake value, hard copy + // use the frenquently used Put method + ASSERT_TRUE(segment.Put(pk, {{1, 100}, {3, 300}, {5, 500}}, cblk, sblk)); + // if first one is clustered index, segment put will fail in the first time, no need to revert + ASSERT_FALSE(segment.Put(pk, {{1, 100}}, cblk, sblk)); + ASSERT_FALSE(segment.Put(pk, {{1, 100}, {3, 300}, {5, 500}}, cblk, sblk)); + ASSERT_EQ(1, (int64_t)segment.GetPkCnt()); + Ticket ticket; + // iter clustered(idx 1), not the secondary, don't create iot iter + std::unique_ptr it( + segment.Segment::NewIterator("test1", 1, ticket, type::CompressType::kNoCompress)); + it->Seek(500); // find less than + ASSERT_TRUE(it->Valid()); + ASSERT_EQ(100, (int64_t)it->GetKey()); + ::openmldb::base::Slice val = it->GetValue(); + std::string result(val.data(), val.size()); + ASSERT_EQ("test0", result); + it->Next(); + ASSERT_FALSE(it->Valid()); // just one row + + // if first one is not the clustered index, we can't know if it exists, be careful + ASSERT_TRUE(segment.Put(pk, {{3, 300}, {5, 500}}, nullptr, nullptr)); +} + +TEST_F(IOTSegmentTest, PutAndScanWhenDefaultTs) { + // in 
the same inner index, it won't have the same ts id + IOTSegment segment(8, {DEFAULT_TS_COL_ID, 3, 5}, + {common::IndexType::kClustered, common::IndexType::kSecondary, common::IndexType::kCovering}); + Slice pk("test1"); + std::string value = "test0"; + auto cblk = new DataBlock(2, value.c_str(), value.size()); // 1 clustered + 1 covering, hard copy + auto sblk = new DataBlock(1, value.c_str(), value.size()); // 1 secondary, fake value, hard copy + // use the frenquently used Put method + ASSERT_TRUE(segment.Put(pk, {{DEFAULT_TS_COL_ID, 100}, {3, 300}, {5, 500}}, cblk, sblk)); + // if first one is clustered index, segment put will fail in the first time, no need to revert + ASSERT_FALSE(segment.Put(pk, {{DEFAULT_TS_COL_ID, 100}}, cblk, sblk)); + ASSERT_FALSE(segment.Put(pk, {{DEFAULT_TS_COL_ID, 100}, {3, 300}, {5, 500}}, cblk, sblk)); + ASSERT_EQ(1, (int64_t)segment.GetPkCnt()); + Ticket ticket; + // iter clustered(idx 1), not the secondary, don't create iot iter + std::unique_ptr it( + segment.Segment::NewIterator("test1", DEFAULT_TS_COL_ID, ticket, type::CompressType::kNoCompress)); + it->Seek(500); // find less than + ASSERT_TRUE(it->Valid()); + ASSERT_EQ(100, (int64_t)it->GetKey()); + ::openmldb::base::Slice val = it->GetValue(); + std::string result(val.data(), val.size()); + ASSERT_EQ("test0", result); + it->Next(); + ASSERT_FALSE(it->Valid()); // just one row + + // if first one is not the clustered index, we can't know if it exists, be careful + ASSERT_TRUE(segment.Put(pk, {{3, 300}, {5, 500}}, nullptr, nullptr)); +} + +TEST_F(IOTSegmentTest, CheckKeyExists) { + IOTSegment segment(8, {1, 3, 5}, + {common::IndexType::kClustered, common::IndexType::kSecondary, common::IndexType::kCovering}); + Slice pk("test1"); + std::string value = "test0"; + auto cblk = new DataBlock(2, value.c_str(), value.size()); // 1 clustered + 1 covering, hard copy + auto sblk = new DataBlock(1, value.c_str(), value.size()); // 1 secondary, fake value, hard copy + // use the 
frenquently used Put method + segment.Put(pk, {{1, 100}, {3, 300}, {5, 500}}, cblk, sblk); + ASSERT_EQ(1, (int64_t)segment.GetPkCnt()); + // check if exists in cidx segment(including 'ttl expired but not gc') + auto st = segment.CheckKeyExists(pk, {{1, 100}}); + ASSERT_TRUE(absl::IsAlreadyExists(st)) << st.ToString(); + st = segment.CheckKeyExists(pk, {{1, 300}}); + ASSERT_TRUE(absl::IsNotFound(st)) << st.ToString(); + // check sidx/covering idx will fail + st = segment.CheckKeyExists(pk, {{3, 300}}); + ASSERT_TRUE(absl::IsInvalidArgument(st)) << st.ToString(); +} + +// report result, don't need to print args in here, just print the failure +::testing::AssertionResult CheckStatisticsInfo(const StatisticsInfo& expect, const StatisticsInfo& value) { + if (expect.idx_cnt_vec.size() != value.idx_cnt_vec.size()) { + return ::testing::AssertionFailure() + << "idx_cnt_vec size expect " << expect.idx_cnt_vec.size() << " but got " << value.idx_cnt_vec.size(); + } + for (size_t idx = 0; idx < expect.idx_cnt_vec.size(); idx++) { + if (expect.idx_cnt_vec[idx] != value.idx_cnt_vec[idx]) { + return ::testing::AssertionFailure() << "idx_cnt_vec[" << idx << "] expect " << expect.idx_cnt_vec[idx] + << " but got " << value.idx_cnt_vec[idx]; + } + } + if (expect.record_byte_size != value.record_byte_size) { + return ::testing::AssertionFailure() + << "record_byte_size expect " << expect.record_byte_size << " but got " << value.record_byte_size; + } + if (expect.idx_byte_size != value.idx_byte_size) { + return ::testing::AssertionFailure() + << "idx_byte_size expect " << expect.idx_byte_size << " but got " << value.idx_byte_size; + } + return ::testing::AssertionSuccess(); +} + +// helper +::testing::AssertionResult CheckStatisticsInfo(std::initializer_list vec, uint64_t idx_byte_size, + uint64_t record_byte_size, const StatisticsInfo& value) { + StatisticsInfo info(0); // overwrite by set idx_cnt_vec + info.idx_cnt_vec = vec; + info.idx_byte_size = idx_byte_size; + 
info.record_byte_size = record_byte_size; + return CheckStatisticsInfo(info, value); +} + +StatisticsInfo CreateStatisticsInfo(uint64_t idx_cnt, uint64_t idx_byte_size, uint64_t record_byte_size) { + StatisticsInfo info(1); + info.idx_cnt_vec[0] = idx_cnt; + info.idx_byte_size = idx_byte_size; + info.record_byte_size = record_byte_size; + return info; +} + +// TODO(hw): gc multi idx has bug, fix later +// TEST_F(IOTSegmentTest, TestGc4Head) { +// IOTSegment segment(8); +// Slice pk("PK"); +// segment.Put(pk, 9768, "test1", 5); +// segment.Put(pk, 9769, "test2", 5); +// StatisticsInfo gc_info(1); +// segment.Gc4Head(1, &gc_info); +// CheckStatisticsInfo(CreateStatisticsInfo(1, 0, GetRecordSize(5)), gc_info); +// Ticket ticket; +// std::unique_ptr it(segment.NewIterator(pk, ticket, type::CompressType::kNoCompress)); +// it->Seek(9769); +// ASSERT_TRUE(it->Valid()); +// ASSERT_EQ(9769, (int64_t)it->GetKey()); +// ::openmldb::base::Slice value = it->GetValue(); +// std::string result(value.data(), value.size()); +// ASSERT_EQ("test2", result); +// it->Next(); +// ASSERT_FALSE(it->Valid()); +// } + +TEST_F(IOTSegmentTest, TestGc4TTL) { + // cidx segment won't execute gc, gc will be done in iot gc + // and multi idx gc `GcAllType` has bug, skip test it + { + std::vector idx_vec = {1}; + std::vector idx_type = {common::IndexType::kClustered}; + auto segment = std::make_unique(8, idx_vec, idx_type); + Slice pk("test1"); + std::string value = "test0"; + auto cblk = new DataBlock(1, value.c_str(), value.size()); // 1 clustered + 1 covering, hard copy + auto sblk = new DataBlock(1, value.c_str(), value.size()); // 1 secondary, fake value, hard copy + ASSERT_TRUE(segment->Put(pk, {{1, 100}}, cblk, sblk)); + // ref iot gc SchedGCByDelete + StatisticsInfo statistics_info(segment->GetTsCnt()); + segment->IncrGcVersion(); + segment->GcFreeList(&statistics_info); + segment->ExecuteGc({{1, {1, 0, TTLType::kAbsoluteTime}}}, &statistics_info, segment->ClusteredTs()); + 
ASSERT_TRUE(CheckStatisticsInfo({0}, 0, 0, statistics_info)); + } + { + std::vector idx_vec = {1}; + std::vector idx_type = {common::IndexType::kSecondary}; + auto segment = std::make_unique(8, idx_vec, idx_type); + Slice pk("test1"); + std::string value = "test0"; + auto cblk = new DataBlock(1, value.c_str(), value.size()); // 1 clustered + 1 covering, hard copy + // execute gc will delete it + auto sblk = new DataBlock(1, value.c_str(), value.size()); // 1 secondary, fake value, hard copy + ASSERT_TRUE(segment->Put(pk, {{1, 100}}, cblk, sblk)); + // ref iot gc SchedGCByDelete + StatisticsInfo statistics_info(segment->GetTsCnt()); + segment->IncrGcVersion(); // 1 + segment->GcFreeList(&statistics_info); + segment->ExecuteGc({{1, {1, 0, TTLType::kAbsoluteTime}}}, &statistics_info, segment->ClusteredTs()); + // secondary will gc, but idx_byte_size is 0(GcFreeList change it) + ASSERT_TRUE(CheckStatisticsInfo({1}, 0, GetRecordSize(5), statistics_info)); + + segment->IncrGcVersion(); // 2 + segment->GcFreeList(&statistics_info); // empty + ASSERT_TRUE(CheckStatisticsInfo({1}, 0, GetRecordSize(5), statistics_info)); + segment->IncrGcVersion(); // delta default is 2, version should >=2, and node_cache free version should >= 3 + segment->GcFreeList(&statistics_info); + // don't know why 197 + ASSERT_TRUE(CheckStatisticsInfo({1}, 197, GetRecordSize(5), statistics_info)); + } +} + +// TEST_F(IOTSegmentTest, TestGc4TTLAndHead) { +// IOTSegment segment(8); +// segment.Put("PK1", 9766, "test1", 5); +// segment.Put("PK1", 9767, "test2", 5); +// segment.Put("PK1", 9768, "test3", 5); +// segment.Put("PK1", 9769, "test4", 5); +// segment.Put("PK2", 9765, "test1", 5); +// segment.Put("PK2", 9766, "test2", 5); +// segment.Put("PK2", 9767, "test3", 5); +// StatisticsInfo gc_info(1); +// // Gc4TTLAndHead only change gc_info.vec[0], check code +// // no expire +// segment.Gc4TTLAndHead(0, 0, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({0}, 0, 0, gc_info)); +// // no lat expire, so 
all records won't be deleted +// segment.Gc4TTLAndHead(9999, 0, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({0}, 0, 0, gc_info)); +// // no abs expire, so all records won't be deleted +// segment.Gc4TTLAndHead(0, 3, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({0}, 0, 0, gc_info)); +// // current_time > expire_time means not expired, so == is outdate and lat 2, so `9765` should be deleted +// segment.Gc4TTLAndHead(9765, 2, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({1}, 0, GetRecordSize(5), gc_info)); +// // gc again, no record expired, info won't update +// segment.Gc4TTLAndHead(9765, 2, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({1}, 0, GetRecordSize(5), gc_info)); +// // new info +// gc_info.Reset(); +// // time <= 9770 is abs expired, but lat 1, so just 1 record per key left, 4 deleted +// segment.Gc4TTLAndHead(9770, 1, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({4}, 0, 4 * GetRecordSize(5), gc_info)); +// uint64_t cnt = 0; +// ASSERT_EQ(0, segment.GetCount("PK1", cnt)); +// ASSERT_EQ(1, cnt); +// ASSERT_EQ(0, segment.GetCount("PK2", cnt)); +// ASSERT_EQ(1, cnt); +// } + +// TEST_F(IOTSegmentTest, TestGc4TTLOrHead) { +// IOTSegment segment(8); +// segment.Put("PK1", 9766, "test1", 5); +// segment.Put("PK1", 9767, "test2", 5); +// segment.Put("PK1", 9768, "test3", 5); +// segment.Put("PK1", 9769, "test4", 5); +// segment.Put("PK2", 9765, "test1", 5); +// segment.Put("PK2", 9766, "test2", 5); +// segment.Put("PK2", 9767, "test3", 5); +// StatisticsInfo gc_info(1); +// // no expire +// segment.Gc4TTLOrHead(0, 0, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({0}, 0, 0, gc_info)); +// // all record <= 9765 should be deleted, no matter the lat expire +// segment.Gc4TTLOrHead(9765, 0, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({1}, 0, GetRecordSize(5), gc_info)); +// gc_info.Reset(); +// // even abs no expire, only lat 3 per key +// segment.Gc4TTLOrHead(0, 3, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({1}, 0, GetRecordSize(5), gc_info)); +// 
gc_info.Reset(); +// segment.Gc4TTLOrHead(9765, 3, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({0}, 0, 0, gc_info)); +// segment.Gc4TTLOrHead(9766, 2, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({2}, 0, 2 * GetRecordSize(5), gc_info)); +// gc_info.Reset(); +// segment.Gc4TTLOrHead(9770, 1, &gc_info); +// ASSERT_TRUE(CheckStatisticsInfo({3}, 0, 3 * GetRecordSize(5), gc_info)); +// } + +// TEST_F(IOTSegmentTest, TestStat) { +// IOTSegment segment(8); +// segment.Put("PK", 9768, "test1", 5); +// segment.Put("PK", 9769, "test2", 5); +// ASSERT_EQ(2, (int64_t)segment.GetIdxCnt()); +// ASSERT_EQ(1, (int64_t)segment.GetPkCnt()); +// StatisticsInfo gc_info(1); +// segment.Gc4TTL(9765, &gc_info); +// ASSERT_EQ(0, gc_info.GetTotalCnt()); +// gc_info.Reset(); +// segment.Gc4TTL(9768, &gc_info); +// ASSERT_EQ(1, (int64_t)segment.GetIdxCnt()); +// ASSERT_EQ(1, gc_info.GetTotalCnt()); +// segment.Gc4TTL(9770, &gc_info); +// ASSERT_EQ(2, gc_info.GetTotalCnt()); +// ASSERT_EQ(0, (int64_t)segment.GetIdxCnt()); +// } + +// TEST_F(IOTSegmentTest, GetTsIdx) { +// std::vector ts_idx_vec = {1, 3, 5}; +// IOTSegment segment(8, ts_idx_vec); +// ASSERT_EQ(3, (int64_t)segment.GetTsCnt()); +// uint32_t real_idx = UINT32_MAX; +// ASSERT_EQ(-1, segment.GetTsIdx(0, real_idx)); +// ASSERT_EQ(0, segment.GetTsIdx(1, real_idx)); +// ASSERT_EQ(0, (int64_t)real_idx); +// ASSERT_EQ(-1, segment.GetTsIdx(2, real_idx)); +// ASSERT_EQ(0, segment.GetTsIdx(3, real_idx)); +// ASSERT_EQ(1, (int64_t)real_idx); +// ASSERT_EQ(-1, segment.GetTsIdx(4, real_idx)); +// ASSERT_EQ(0, segment.GetTsIdx(5, real_idx)); +// ASSERT_EQ(2, (int64_t)real_idx); +// } + +// int GetCount(IOTSegment* segment, int idx) { +// int count = 0; +// std::unique_ptr pk_it(segment->GetKeyEntries()->NewIterator()); +// if (!pk_it) { +// return 0; +// } +// uint32_t real_idx = idx; +// segment->GetTsIdx(idx, real_idx); +// pk_it->SeekToFirst(); +// while (pk_it->Valid()) { +// KeyEntry* entry = nullptr; +// if (segment->GetTsCnt() > 1) 
{ +// entry = reinterpret_cast(pk_it->GetValue())[real_idx]; +// } else { +// entry = reinterpret_cast(pk_it->GetValue()); +// } +// std::unique_ptr ts_it(entry->entries.NewIterator()); +// ts_it->SeekToFirst(); +// while (ts_it->Valid()) { +// count++; +// ts_it->Next(); +// } +// pk_it->Next(); +// } +// return count; +// } + +// TEST_F(IOTSegmentTest, ReleaseAndCount) { +// std::vector ts_idx_vec = {1, 3}; +// IOTSegment segment(8, ts_idx_vec); +// ASSERT_EQ(2, (int64_t)segment.GetTsCnt()); +// for (int i = 0; i < 100; i++) { +// std::string key = "key" + std::to_string(i); +// uint64_t ts = 1669013677221000; +// for (int j = 0; j < 2; j++) { +// DataBlock* data = new DataBlock(2, key.c_str(), key.length()); +// std::map ts_map = {{1, ts + j}, {3, ts + j}}; +// segment.Put(Slice(key), ts_map, data); +// } +// } +// ASSERT_EQ(200, GetCount(&segment, 1)); +// ASSERT_EQ(200, GetCount(&segment, 3)); +// StatisticsInfo gc_info(1); +// segment.ReleaseAndCount({1}, &gc_info); +// ASSERT_EQ(0, GetCount(&segment, 1)); +// ASSERT_EQ(200, GetCount(&segment, 3)); +// segment.ReleaseAndCount(&gc_info); +// ASSERT_EQ(0, GetCount(&segment, 1)); +// ASSERT_EQ(0, GetCount(&segment, 3)); +// } + +// TEST_F(IOTSegmentTest, ReleaseAndCountOneTs) { +// IOTSegment segment(8); +// for (int i = 0; i < 100; i++) { +// std::string key = "key" + std::to_string(i); +// uint64_t ts = 1669013677221000; +// for (int j = 0; j < 2; j++) { +// segment.Put(Slice(key), ts + j, key.c_str(), key.size()); +// } +// } +// StatisticsInfo gc_info(1); +// ASSERT_EQ(200, GetCount(&segment, 0)); +// segment.ReleaseAndCount(&gc_info); +// ASSERT_EQ(0, GetCount(&segment, 0)); +// } + +// TEST_F(IOTSegmentTest, TestDeleteRange) { +// IOTSegment segment(8); +// for (int idx = 0; idx < 10; idx++) { +// std::string key = absl::StrCat("key", idx); +// std::string value = absl::StrCat("value", idx); +// uint64_t ts = 1000; +// for (int i = 0; i < 10; i++) { +// segment.Put(Slice(key), ts + i, value.data(), 6); +// 
} +// } +// ASSERT_EQ(100, GetCount(&segment, 0)); +// std::string pk = "key2"; +// Ticket ticket; +// std::unique_ptr it(segment.NewIterator(pk, ticket, type::CompressType::kNoCompress)); +// it->Seek(1005); +// ASSERT_TRUE(it->Valid() && it->GetKey() == 1005); +// ASSERT_TRUE(segment.Delete(std::nullopt, pk, 1005, 1004)); +// ASSERT_EQ(99, GetCount(&segment, 0)); +// it->Seek(1005); +// ASSERT_FALSE(it->Valid() && it->GetKey() == 1005); +// ASSERT_TRUE(segment.Delete(std::nullopt, pk, 1005, std::nullopt)); +// ASSERT_EQ(94, GetCount(&segment, 0)); +// it->Seek(1005); +// ASSERT_FALSE(it->Valid()); +// pk = "key3"; +// ASSERT_TRUE(segment.Delete(std::nullopt, pk)); +// pk = "key4"; +// ASSERT_TRUE(segment.Delete(std::nullopt, pk, 1005, 1001)); +// ASSERT_EQ(80, GetCount(&segment, 0)); +// segment.IncrGcVersion(); +// segment.IncrGcVersion(); +// StatisticsInfo gc_info(1); +// segment.GcFreeList(&gc_info); +// CheckStatisticsInfo(CreateStatisticsInfo(20, 1012, 20 * (6 + sizeof(DataBlock))), gc_info); +// } + +// TEST_F(IOTSegmentTest, PutIfAbsent) { +// { +// IOTSegment segment(8); // so ts_cnt_ == 1 +// // check all time == false +// segment.Put("PK", 1, "test1", 5, true); +// segment.Put("PK", 1, "test2", 5, true); // even key&time is the same, different value means different record +// ASSERT_EQ(2, (int64_t)segment.GetIdxCnt()); +// ASSERT_EQ(1, (int64_t)segment.GetPkCnt()); +// segment.Put("PK", 2, "test3", 5, true); +// segment.Put("PK", 2, "test4", 5, true); +// segment.Put("PK", 3, "test5", 5, true); +// segment.Put("PK", 3, "test6", 5, true); +// ASSERT_EQ(6, (int64_t)segment.GetIdxCnt()); +// // insert exists rows +// segment.Put("PK", 2, "test3", 5, true); +// segment.Put("PK", 1, "test1", 5, true); +// segment.Put("PK", 1, "test2", 5, true); +// segment.Put("PK", 3, "test6", 5, true); +// ASSERT_EQ(6, (int64_t)segment.GetIdxCnt()); +// // new rows +// segment.Put("PK", 2, "test7", 5, true); +// ASSERT_EQ(7, (int64_t)segment.GetIdxCnt()); +// 
segment.Put("PK", 0, "test8", 5, true); // seek to last, next is empty +// ASSERT_EQ(8, (int64_t)segment.GetIdxCnt()); +// } + +// { +// // support when ts_cnt_ != 1 too +// std::vector ts_idx_vec = {1, 3}; +// IOTSegment segment(8, ts_idx_vec); +// ASSERT_EQ(2, (int64_t)segment.GetTsCnt()); +// std::string key = "PK"; +// uint64_t ts = 1669013677221000; +// // the same ts +// for (int j = 0; j < 2; j++) { +// DataBlock* data = new DataBlock(2, key.c_str(), key.length()); +// std::map ts_map = {{1, ts}, {3, ts}}; +// segment.Put(Slice(key), ts_map, data, true); +// } +// ASSERT_EQ(1, GetCount(&segment, 1)); +// ASSERT_EQ(1, GetCount(&segment, 3)); +// } + +// { +// // put ts_map contains DEFAULT_TS_COL_ID +// std::vector ts_idx_vec = {DEFAULT_TS_COL_ID}; +// IOTSegment segment(8, ts_idx_vec); +// ASSERT_EQ(1, (int64_t)segment.GetTsCnt()); +// std::string key = "PK"; +// std::map ts_map = {{DEFAULT_TS_COL_ID, 100}}; // cur time == 100 +// auto* block = new DataBlock(1, "test1", 5); +// segment.Put(Slice(key), ts_map, block, true); +// ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID)); +// ts_map = {{DEFAULT_TS_COL_ID, 200}}; +// block = new DataBlock(1, "test1", 5); +// segment.Put(Slice(key), ts_map, block, true); +// ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID)); +// } + +// { +// // put ts_map contains DEFAULT_TS_COL_ID +// std::vector ts_idx_vec = {DEFAULT_TS_COL_ID, 1, 3}; +// IOTSegment segment(8, ts_idx_vec); +// ASSERT_EQ(3, (int64_t)segment.GetTsCnt()); +// std::string key = "PK"; +// std::map ts_map = {{DEFAULT_TS_COL_ID, 100}}; // cur time == 100 +// auto* block = new DataBlock(1, "test1", 5); +// segment.Put(Slice(key), ts_map, block, true); +// ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID)); +// ts_map = {{DEFAULT_TS_COL_ID, 200}}; +// block = new DataBlock(1, "test1", 5); +// segment.Put(Slice(key), ts_map, block, true); +// ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID)); +// } +// } + +} // namespace storage +} // namespace openmldb 
+ +int main(int argc, char** argv) { + ::openmldb::base::SetLogLevel(INFO); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/storage/key_entry.cc b/src/storage/key_entry.cc index 2713510f16c..8af33e8fc70 100644 --- a/src/storage/key_entry.cc +++ b/src/storage/key_entry.cc @@ -36,7 +36,7 @@ void KeyEntry::Release(uint32_t idx, StatisticsInfo* statistics_info) { if (node->GetValue()->dim_cnt_down > 1) { node->GetValue()->dim_cnt_down--; } else { - DEBUGLOG("delele data block for key %lu", node->GetKey()); + VLOG(1) << "delete data block for key " << node->GetKey(); statistics_info->record_byte_size += GetRecordSize(node->GetValue()->size); delete node->GetValue(); } diff --git a/src/storage/mem_table.cc b/src/storage/mem_table.cc index 023974e3c6b..910fdc06e4d 100644 --- a/src/storage/mem_table.cc +++ b/src/storage/mem_table.cc @@ -17,6 +17,7 @@ #include "storage/mem_table.h" #include + #include #include @@ -26,8 +27,8 @@ #include "common/timer.h" #include "gflags/gflags.h" #include "schema/index_util.h" -#include "storage/record.h" #include "storage/mem_table_iterator.h" +#include "storage/record.h" DECLARE_uint32(skiplist_max_height); DECLARE_uint32(skiplist_max_height); @@ -54,7 +55,7 @@ MemTable::MemTable(const ::openmldb::api::TableMeta& table_meta) : Table(table_meta.storage_mode(), table_meta.name(), table_meta.tid(), table_meta.pid(), 0, true, 60 * 1000, std::map(), ::openmldb::type::TTLType::kAbsoluteTime, ::openmldb::type::CompressType::kNoCompress), - segments_(MAX_INDEX_NUM, nullptr) { + segments_(MAX_INDEX_NUM, nullptr) { seg_cnt_ = 8; enable_gc_ = true; segment_released_ = false; @@ -80,7 +81,7 @@ MemTable::~MemTable() { PDLOG(INFO, "drop memtable. 
tid %u pid %u", id_, pid_); } -bool MemTable::Init() { +bool MemTable::InitMeta() { key_entry_max_height_ = FLAGS_key_entry_max_height; if (!InitFromMeta()) { return false; @@ -88,21 +89,33 @@ bool MemTable::Init() { if (table_meta_->seg_cnt() > 0) { seg_cnt_ = table_meta_->seg_cnt(); } + return true; +} + +uint32_t MemTable::KeyEntryMaxHeight(const std::shared_ptr& inner_idx) { uint32_t global_key_entry_max_height = 0; if (table_meta_->has_key_entry_max_height() && table_meta_->key_entry_max_height() <= FLAGS_skiplist_max_height && table_meta_->key_entry_max_height() > 0) { global_key_entry_max_height = table_meta_->key_entry_max_height(); } + if (global_key_entry_max_height > 0) { + return global_key_entry_max_height; + } else { + return inner_idx->GetKeyEntryMaxHeight(FLAGS_absolute_default_skiplist_height, + FLAGS_latest_default_skiplist_height); + } +} +bool MemTable::Init() { + if (!InitMeta()) { + LOG(WARNING) << "init meta failed. tid " << id_ << " pid " << pid_; + return false; + } + auto inner_indexs = table_index_.GetAllInnerIndex(); for (uint32_t i = 0; i < inner_indexs->size(); i++) { const std::vector& ts_vec = inner_indexs->at(i)->GetTsIdx(); - uint32_t cur_key_entry_max_height = 0; - if (global_key_entry_max_height > 0) { - cur_key_entry_max_height = global_key_entry_max_height; - } else { - cur_key_entry_max_height = inner_indexs->at(i)->GetKeyEntryMaxHeight(FLAGS_absolute_default_skiplist_height, - FLAGS_latest_default_skiplist_height); - } + uint32_t cur_key_entry_max_height = KeyEntryMaxHeight(inner_indexs->at(i)); + Segment** seg_arr = new Segment*[seg_cnt_]; if (!ts_vec.empty()) { for (uint32_t j = 0; j < seg_cnt_; j++) { @@ -226,10 +239,8 @@ absl::Status MemTable::Put(uint64_t time, const std::string& value, const Dimens } bool MemTable::Delete(const ::openmldb::api::LogEntry& entry) { - std::optional start_ts = entry.has_ts() ? std::optional{entry.ts()} - : std::nullopt; - std::optional end_ts = entry.has_end_ts() ? 
std::optional{entry.end_ts()} - : std::nullopt; + std::optional start_ts = entry.has_ts() ? std::optional{entry.ts()} : std::nullopt; + std::optional end_ts = entry.has_end_ts() ? std::optional{entry.end_ts()} : std::nullopt; if (entry.dimensions_size() > 0) { for (const auto& dimension : entry.dimensions()) { if (!Delete(dimension.idx(), dimension.key(), start_ts, end_ts)) { @@ -259,8 +270,8 @@ bool MemTable::Delete(const ::openmldb::api::LogEntry& entry) { return true; } -bool MemTable::Delete(uint32_t idx, const std::string& key, - const std::optional& start_ts, const std::optional& end_ts) { +bool MemTable::Delete(uint32_t idx, const std::string& key, const std::optional& start_ts, + const std::optional& end_ts) { auto index_def = GetIndex(idx); if (!index_def || !index_def->IsReady()) { return false; @@ -336,7 +347,7 @@ void MemTable::SchedGc() { for (uint32_t k = 0; k < seg_cnt_; k++) { if (segments_[i][k] != nullptr) { StatisticsInfo statistics_info(segments_[i][k]->GetTsCnt()); - if (real_index.size() == 1 || deleting_pos.size() + deleted_num == real_index.size()) { + if (real_index.size() == 1 || deleting_pos.size() + deleted_num == real_index.size()) { segments_[i][k]->ReleaseAndCount(&statistics_info); } else { segments_[i][k]->ReleaseAndCount(deleting_pos, &statistics_info); @@ -377,8 +388,8 @@ void MemTable::SchedGc() { } consumed = ::baidu::common::timer::get_micros() - consumed; record_byte_size_.fetch_sub(gc_record_byte_size, std::memory_order_relaxed); - PDLOG(INFO, "gc finished, gc_idx_cnt %lu, consumed %lu ms for table %s tid %u pid %u", - gc_idx_cnt, consumed / 1000, name_.c_str(), id_, pid_); + PDLOG(INFO, "gc finished, gc_idx_cnt %lu, consumed %lu ms for table %s tid %u pid %u", gc_idx_cnt, consumed / 1000, + name_.c_str(), id_, pid_); UpdateTTL(); } @@ -620,18 +631,25 @@ bool MemTable::GetRecordIdxCnt(uint32_t idx, uint64_t** stat, uint32_t* size) { } bool MemTable::AddIndexToTable(const std::shared_ptr& index_def) { - std::vector ts_vec = { 
index_def->GetTsColumn()->GetId() }; + std::vector ts_vec = {index_def->GetTsColumn()->GetId()}; uint32_t inner_id = index_def->GetInnerPos(); Segment** seg_arr = new Segment*[seg_cnt_]; for (uint32_t j = 0; j < seg_cnt_; j++) { seg_arr[j] = new Segment(FLAGS_absolute_default_skiplist_height, ts_vec); PDLOG(INFO, "init %u, %u segment. height %u, ts col num %u. tid %u pid %u", inner_id, j, - FLAGS_absolute_default_skiplist_height, ts_vec.size(), id_, pid_); + FLAGS_absolute_default_skiplist_height, ts_vec.size(), id_, pid_); } segments_[inner_id] = seg_arr; return true; } +uint32_t MemTable::SegIdx(const std::string& pk) { + if (seg_cnt_ > 1) { + return ::openmldb::base::hash(pk.c_str(), pk.length(), SEED) % seg_cnt_; + } + return 0; +} + ::hybridse::vm::WindowIterator* MemTable::NewWindowIterator(uint32_t index) { std::shared_ptr index_def = table_index_.GetIndex(index); if (!index_def || !index_def->IsReady()) { @@ -651,8 +669,8 @@ ::hybridse::vm::WindowIterator* MemTable::NewWindowIterator(uint32_t index) { if (ts_col) { ts_idx = ts_col->GetId(); } - return new MemTableKeyIterator(segments_[real_idx], seg_cnt_, ttl->ttl_type, - expire_time, expire_cnt, ts_idx, GetCompressType()); + return new MemTableKeyIterator(segments_[real_idx], seg_cnt_, ttl->ttl_type, expire_time, expire_cnt, ts_idx, + GetCompressType()); } TraverseIterator* MemTable::NewTraverseIterator(uint32_t index) { @@ -671,11 +689,11 @@ TraverseIterator* MemTable::NewTraverseIterator(uint32_t index) { uint32_t real_idx = index_def->GetInnerPos(); auto ts_col = index_def->GetTsColumn(); if (ts_col) { - return new MemTableTraverseIterator(segments_[real_idx], seg_cnt_, ttl->ttl_type, - expire_time, expire_cnt, ts_col->GetId(), GetCompressType()); + return new MemTableTraverseIterator(segments_[real_idx], seg_cnt_, ttl->ttl_type, expire_time, expire_cnt, + ts_col->GetId(), GetCompressType()); } - return new MemTableTraverseIterator(segments_[real_idx], seg_cnt_, ttl->ttl_type, - expire_time, expire_cnt, 
0, GetCompressType()); + return new MemTableTraverseIterator(segments_[real_idx], seg_cnt_, ttl->ttl_type, expire_time, expire_cnt, 0, + GetCompressType()); } bool MemTable::GetBulkLoadInfo(::openmldb::api::BulkLoadInfoResponse* response) { diff --git a/src/storage/mem_table.h b/src/storage/mem_table.h index 694203c3e40..c85ffd12da4 100644 --- a/src/storage/mem_table.h +++ b/src/storage/mem_table.h @@ -54,10 +54,10 @@ class MemTable : public Table { absl::Status Put(uint64_t time, const std::string& value, const Dimensions& dimensions, bool put_if_absent) override; - bool GetBulkLoadInfo(::openmldb::api::BulkLoadInfoResponse* response); + virtual bool GetBulkLoadInfo(::openmldb::api::BulkLoadInfoResponse* response); - bool BulkLoad(const std::vector& data_blocks, - const ::google::protobuf::RepeatedPtrField<::openmldb::api::BulkLoadIndex>& indexes); + virtual bool BulkLoad(const std::vector& data_blocks, + const ::google::protobuf::RepeatedPtrField<::openmldb::api::BulkLoadIndex>& indexes); bool Delete(const ::openmldb::api::LogEntry& entry) override; @@ -68,7 +68,7 @@ class MemTable : public Table { TraverseIterator* NewTraverseIterator(uint32_t index) override; - ::hybridse::vm::WindowIterator* NewWindowIterator(uint32_t index); + ::hybridse::vm::WindowIterator* NewWindowIterator(uint32_t index) override; // release all memory allocated uint64_t Release(); @@ -104,15 +104,26 @@ class MemTable : public Table { protected: bool AddIndexToTable(const std::shared_ptr& index_def) override; + uint32_t SegIdx(const std::string& pk); + + Segment* GetSegment(uint32_t real_idx, uint32_t seg_idx) { + // TODO(hw): protect + return segments_[real_idx][seg_idx]; + } + Segment** GetSegments(uint32_t real_idx) { return segments_[real_idx]; } + + bool InitMeta(); + uint32_t KeyEntryMaxHeight(const std::shared_ptr& inner_idx); + private: bool CheckAbsolute(const TTLSt& ttl, uint64_t ts); bool CheckLatest(uint32_t index_id, const std::string& key, uint64_t ts); - bool 
Delete(uint32_t idx, const std::string& key, - const std::optional& start_ts, const std::optional& end_ts); + bool Delete(uint32_t idx, const std::string& key, const std::optional& start_ts, + const std::optional& end_ts); - private: + protected: uint32_t seg_cnt_; std::vector segments_; std::atomic enable_gc_; diff --git a/src/storage/mem_table_iterator.cc b/src/storage/mem_table_iterator.cc index 22cd7964640..f508d404af7 100644 --- a/src/storage/mem_table_iterator.cc +++ b/src/storage/mem_table_iterator.cc @@ -138,7 +138,7 @@ void MemTableKeyIterator::Next() { NextPK(); } -::hybridse::vm::RowIterator* MemTableKeyIterator::GetRawValue() { +TimeEntries::Iterator* MemTableKeyIterator::GetTimeIter() { TimeEntries::Iterator* it = nullptr; if (segments_[seg_idx_]->GetTsCnt() > 1) { KeyEntry* entry = ((KeyEntry**)pk_it_->GetValue())[ts_idx_]; // NOLINT @@ -150,6 +150,11 @@ ::hybridse::vm::RowIterator* MemTableKeyIterator::GetRawValue() { ticket_.Push((KeyEntry*)pk_it_->GetValue()); // NOLINT } it->SeekToFirst(); + return it; +} + +::hybridse::vm::RowIterator* MemTableKeyIterator::GetRawValue() { + TimeEntries::Iterator* it = GetTimeIter(); return new MemTableWindowIterator(it, ttl_type_, expire_time_, expire_cnt_, compress_type_); } diff --git a/src/storage/mem_table_iterator.h b/src/storage/mem_table_iterator.h index 4b3b2514824..427cdc09100 100644 --- a/src/storage/mem_table_iterator.h +++ b/src/storage/mem_table_iterator.h @@ -18,6 +18,7 @@ #include #include + #include "storage/segment.h" #include "vm/catalog.h" @@ -27,9 +28,12 @@ namespace storage { class MemTableWindowIterator : public ::hybridse::vm::RowIterator { public: MemTableWindowIterator(TimeEntries::Iterator* it, ::openmldb::storage::TTLType ttl_type, uint64_t expire_time, - uint64_t expire_cnt, type::CompressType compress_type) - : it_(it), record_idx_(1), expire_value_(expire_time, expire_cnt, ttl_type), - row_(), compress_type_(compress_type) {} + uint64_t expire_cnt, type::CompressType compress_type) + 
: it_(it), + record_idx_(1), + expire_value_(expire_time, expire_cnt, ttl_type), + row_(), + compress_type_(compress_type) {} ~MemTableWindowIterator(); @@ -59,8 +63,7 @@ class MemTableWindowIterator : public ::hybridse::vm::RowIterator { class MemTableKeyIterator : public ::hybridse::vm::WindowIterator { public: MemTableKeyIterator(Segment** segments, uint32_t seg_cnt, ::openmldb::storage::TTLType ttl_type, - uint64_t expire_time, uint64_t expire_cnt, uint32_t ts_index, - type::CompressType compress_type); + uint64_t expire_time, uint64_t expire_cnt, uint32_t ts_index, type::CompressType compress_type); ~MemTableKeyIterator() override; @@ -77,10 +80,13 @@ class MemTableKeyIterator : public ::hybridse::vm::WindowIterator { const hybridse::codec::Row GetKey() override; + protected: + TimeEntries::Iterator* GetTimeIter(); + private: void NextPK(); - private: + protected: Segment** segments_; uint32_t const seg_cnt_; uint32_t seg_idx_; @@ -97,10 +103,10 @@ class MemTableKeyIterator : public ::hybridse::vm::WindowIterator { class MemTableTraverseIterator : public TraverseIterator { public: MemTableTraverseIterator(Segment** segments, uint32_t seg_cnt, ::openmldb::storage::TTLType ttl_type, - uint64_t expire_time, uint64_t expire_cnt, uint32_t ts_index, - type::CompressType compress_type); + uint64_t expire_time, uint64_t expire_cnt, uint32_t ts_index, + type::CompressType compress_type); ~MemTableTraverseIterator() override; - inline bool Valid() override; + bool Valid() override; void Next() override; void NextPK() override; void Seek(const std::string& key, uint64_t time) override; diff --git a/src/storage/node_cache.cc b/src/storage/node_cache.cc index 0f1286494e8..766dfe78be3 100644 --- a/src/storage/node_cache.cc +++ b/src/storage/node_cache.cc @@ -79,6 +79,7 @@ void NodeCache::Free(uint64_t version, StatisticsInfo* gc_info) { node1 = key_entry_node_list_.Split(version); node2 = value_node_list_.Split(version); } + DLOG(INFO) << "free version " << version << ", 
node1 " << node1 << ", node2 " << node2; while (node1) { auto entry_node_list = node1->GetValue(); for (auto& entry_node : *entry_node_list) { @@ -113,11 +114,11 @@ void NodeCache::FreeNode(uint32_t idx, base::Node* node, S } gc_info->IncrIdxCnt(idx); gc_info->idx_byte_size += GetRecordTsIdxSize(node->Height()); - DLOG(INFO) << "delete key " << node->GetKey() << " with height " << node->Height(); + VLOG(1) << "delete key " << node->GetKey() << " with height " << (unsigned int)node->Height(); if (node->GetValue()->dim_cnt_down > 1) { node->GetValue()->dim_cnt_down--; } else { - DLOG(INFO) << "delele data block for key " << node->GetKey(); + VLOG(1) << "delete data block for key " << node->GetKey(); gc_info->record_byte_size += GetRecordSize(node->GetValue()->size); delete node->GetValue(); } diff --git a/src/storage/record.h b/src/storage/record.h index b3b06611f90..a5ab3f651ff 100644 --- a/src/storage/record.h +++ b/src/storage/record.h @@ -18,8 +18,10 @@ #define SRC_STORAGE_RECORD_H_ #include -#include "base/slice.h" + +#include "absl/strings/str_cat.h" #include "base/skiplist.h" +#include "base/slice.h" #include "storage/key_entry.h" namespace openmldb { @@ -67,9 +69,7 @@ struct StatisticsInfo { } } - uint64_t GetIdxCnt(uint32_t idx) const { - return idx >= idx_cnt_vec.size() ? 0 : idx_cnt_vec[idx]; - } + uint64_t GetIdxCnt(uint32_t idx) const { return idx >= idx_cnt_vec.size() ? 
0 : idx_cnt_vec[idx]; } uint64_t GetTotalCnt() const { uint64_t total_cnt = 0; @@ -79,6 +79,15 @@ struct StatisticsInfo { return total_cnt; } + std::string DebugString() { + std::string str; + absl::StrAppend(&str, "idx_byte_size: ", idx_byte_size, " record_byte_size: ", record_byte_size, " idx_cnt: "); + for (uint32_t i = 0; i < idx_cnt_vec.size(); i++) { + absl::StrAppend(&str, i, ":", idx_cnt_vec[i], " "); + } + return str; + } + std::vector idx_cnt_vec; uint64_t idx_byte_size = 0; uint64_t record_byte_size = 0; diff --git a/src/storage/schema.cc b/src/storage/schema.cc index 3250a047a8b..f8a9d4fa4a6 100644 --- a/src/storage/schema.cc +++ b/src/storage/schema.cc @@ -129,6 +129,21 @@ uint32_t InnerIndexSt::GetKeyEntryMaxHeight(uint32_t abs_max_height, uint32_t la return max_height; } +int64_t InnerIndexSt::ClusteredTsId() { + int64_t id = -1; + for (const auto& cur_index : index_) { + if (cur_index->IsClusteredIndex()) { + auto ts_col = cur_index->GetTsColumn(); + DLOG_ASSERT(ts_col) << "clustered index should have ts column, even auto gen"; + if (ts_col) { + id = ts_col->GetId(); + } + } + } + return id; +} + + TableIndex::TableIndex() { indexs_ = std::make_shared>>(); inner_indexs_ = std::make_shared>>(); @@ -195,7 +210,8 @@ int TableIndex::ParseFromMeta(const ::openmldb::api::TableMeta& table_meta) { } } } - uint32_t key_idx = 0; + + // pos == idx for (int pos = 0; pos < table_meta.column_key_size(); pos++) { const auto& column_key = table_meta.column_key(pos); std::string name = column_key.index_name(); @@ -209,8 +225,10 @@ int TableIndex::ParseFromMeta(const ::openmldb::api::TableMeta& table_meta) { for (const auto& cur_col_name : column_key.col_name()) { col_vec.push_back(*(col_map[cur_col_name])); } - auto index = std::make_shared(column_key.index_name(), key_idx, status, - ::openmldb::type::IndexType::kTimeSerise, col_vec); + // index type is optional + common::IndexType index_type = column_key.has_type() ? 
column_key.type() : common::IndexType::kCovering; + auto index = std::make_shared(column_key.index_name(), pos, status, + ::openmldb::type::IndexType::kTimeSerise, col_vec, index_type); if (!column_key.ts_name().empty()) { const std::string& ts_name = column_key.ts_name(); index->SetTsColumn(col_map[ts_name]); @@ -226,7 +244,6 @@ int TableIndex::ParseFromMeta(const ::openmldb::api::TableMeta& table_meta) { DLOG(WARNING) << "add index failed"; return -1; } - key_idx++; } } // add default dimension diff --git a/src/storage/schema.h b/src/storage/schema.h index 9edc6e54b2a..39ee5891700 100644 --- a/src/storage/schema.h +++ b/src/storage/schema.h @@ -24,6 +24,7 @@ #include #include +#include "base/glog_wrapper.h" #include "common/timer.h" #include "proto/name_server.pb.h" #include "proto/tablet.pb.h" @@ -35,13 +36,7 @@ static constexpr uint32_t MAX_INDEX_NUM = 200; static constexpr uint32_t DEFAULT_TS_COL_ID = UINT32_MAX; static constexpr const char* DEFAULT_TS_COL_NAME = "___default_ts___"; -enum TTLType { - kAbsoluteTime = 1, - kRelativeTime = 2, - kLatestTime = 3, - kAbsAndLat = 4, - kAbsOrLat = 5 -}; +enum TTLType { kAbsoluteTime = 1, kRelativeTime = 2, kLatestTime = 3, kAbsAndLat = 4, kAbsOrLat = 5 }; // ttl unit: millisecond struct TTLSt { @@ -147,8 +142,7 @@ struct TTLSt { }; struct ExpiredChecker { - ExpiredChecker(uint64_t abs, uint64_t lat, TTLType type) : - abs_expired_ttl(abs), lat_ttl(lat), ttl_type(type) {} + ExpiredChecker(uint64_t abs, uint64_t lat, TTLType type) : abs_expired_ttl(abs), lat_ttl(lat), ttl_type(type) {} bool IsExpired(uint64_t abs, uint32_t record_idx) const { switch (ttl_type) { case TTLType::kAbsoluteTime: @@ -234,6 +228,11 @@ class IndexDef { IndexDef(const std::string& name, uint32_t id, IndexStatus status); IndexDef(const std::string& name, uint32_t id, const IndexStatus& status, ::openmldb::type::IndexType type, const std::vector& column_idx_map); + IndexDef(const std::string& name, uint32_t id, const IndexStatus& status, 
::openmldb::type::IndexType type, + const std::vector& column_idx_map, common::IndexType index_type) + : IndexDef(name, id, status, type, column_idx_map) { + index_type_ = index_type; + } const std::string& GetName() const { return name_; } inline const std::shared_ptr& GetTsColumn() const { return ts_column_; } void SetTsColumn(const std::shared_ptr& ts_column) { ts_column_ = ts_column; } @@ -250,15 +249,22 @@ class IndexDef { inline uint32_t GetInnerPos() const { return inner_pos_; } ::openmldb::common::ColumnKey GenColumnKey(); + common::IndexType GetIndexType() const { return index_type_; } + bool IsSecondaryIndex() { return index_type_ == common::IndexType::kSecondary; } + bool IsClusteredIndex() { return index_type_ == common::IndexType::kClustered; } + private: std::string name_; uint32_t index_id_; uint32_t inner_pos_; std::atomic status_; + // for compatible, type is only kTimeSerise ::openmldb::type::IndexType type_; std::vector columns_; std::shared_ptr ttl_st_; std::shared_ptr ts_column_; + // 0 covering, 1 clustered, 2 secondary, default 0 + common::IndexType index_type_ = common::IndexType::kCovering; }; class InnerIndexSt { @@ -270,11 +276,22 @@ class InnerIndexSt { ts_.push_back(ts_col->GetId()); } } + LOG_IF(DFATAL, ts_.size() != index_.size()) << "ts size not equal to index size"; } inline uint32_t GetId() const { return id_; } inline const std::vector& GetTsIdx() const { return ts_; } + // len(ts) == len(type) + inline std::vector GetTsIdxType() const { + std::vector ts_idx_type; + for (const auto& cur_index : index_) { + if (cur_index->GetTsColumn()) ts_idx_type.push_back(cur_index->GetIndexType()); + } + return ts_idx_type; + } inline const std::vector>& GetIndex() const { return index_; } uint32_t GetKeyEntryMaxHeight(uint32_t abs_max_height, uint32_t lat_max_height) const; + // -1 means no clustered idx in here, it's safe to cvt to uint32_t when id >= 0 + int64_t ClusteredTsId(); private: const uint32_t id_; diff --git 
a/src/storage/segment.cc b/src/storage/segment.cc index 6eb721d353c..18a47c961c4 100644 --- a/src/storage/segment.cc +++ b/src/storage/segment.cc @@ -175,6 +175,7 @@ bool Segment::PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool pu reinterpret_cast(entry)->count_.fetch_add(1, std::memory_order_relaxed); byte_size += GetRecordTsIdxSize(height); idx_byte_size_.fetch_add(byte_size, std::memory_order_relaxed); + DLOG(INFO) << "idx_byte_size_ " << idx_byte_size_ << " after add " << byte_size; return true; } @@ -251,6 +252,7 @@ bool Segment::Put(const Slice& key, const std::map& ts_map, D entry->count_.fetch_add(1, std::memory_order_relaxed); byte_size += GetRecordTsIdxSize(height); idx_byte_size_.fetch_add(byte_size, std::memory_order_relaxed); + DLOG(INFO) << "idx_byte_size_ " << idx_byte_size_ << " after add " << byte_size; idx_cnt_vec_[pos->second]->fetch_add(1, std::memory_order_relaxed); } return true; @@ -268,11 +270,13 @@ bool Segment::Delete(const std::optional& idx, const Slice& key) { entry_node = entries_->Remove(key); } if (entry_node != nullptr) { + DLOG(INFO) << "add key " << key.ToString() << " to node cache. 
version " << gc_version_; node_cache_.AddKeyEntryNode(gc_version_.load(std::memory_order_relaxed), entry_node); return true; } } else { base::Node* data_node = nullptr; + ::openmldb::base::Node* entry_node = nullptr; { std::lock_guard lock(mu_); void* entry_arr = nullptr; @@ -286,10 +290,24 @@ bool Segment::Delete(const std::optional& idx, const Slice& key) { uint64_t ts = it->GetKey(); data_node = key_entry->entries.Split(ts); } + bool is_empty = true; + for (uint32_t i = 0; i < ts_cnt_; i++) { + if (!reinterpret_cast(entry_arr)[i]->entries.IsEmpty()) { + is_empty = false; + break; + } + } + if (is_empty) { + entry_node = entries_->Remove(key); + } } if (data_node != nullptr) { node_cache_.AddValueNodeList(ts_idx, gc_version_.load(std::memory_order_relaxed), data_node); } + if (entry_node != nullptr) { + DLOG(INFO) << "add key " << key.ToString() << " to node cache. version " << gc_version_; + node_cache_.AddKeyEntryNode(gc_version_.load(std::memory_order_relaxed), entry_node); + } } return true; } @@ -313,12 +331,13 @@ bool Segment::GetTsIdx(const std::optional& idx, uint32_t* ts_idx) { return true; } -bool Segment::Delete(const std::optional& idx, const Slice& key, - uint64_t ts, const std::optional& end_ts) { +bool Segment::Delete(const std::optional& idx, const Slice& key, uint64_t ts, + const std::optional& end_ts) { uint32_t ts_idx = 0; if (!GetTsIdx(idx, &ts_idx)) { return false; } + void* entry = nullptr; if (entries_->Get(key, entry) < 0 || entry == nullptr) { return true; @@ -351,14 +370,33 @@ bool Segment::Delete(const std::optional& idx, const Slice& key, } } base::Node* data_node = nullptr; + base::Node* entry_node = nullptr; { std::lock_guard lock(mu_); data_node = key_entry->entries.Split(ts); - DLOG(INFO) << "entry " << key.ToString() << " split by " << ts; + DLOG(INFO) << "after delete, entry " << key.ToString() << " split by " << ts; + bool is_empty = true; + if (ts_cnt_ == 1) { + is_empty = key_entry->entries.IsEmpty(); + } else { + for 
(uint32_t i = 0; i < ts_cnt_; i++) { + if (!reinterpret_cast(entry)[i]->entries.IsEmpty()) { + is_empty = false; + break; + } + } + } + if (is_empty) { + entry_node = entries_->Remove(key); + } } if (data_node != nullptr) { node_cache_.AddValueNodeList(ts_idx, gc_version_.load(std::memory_order_relaxed), data_node); } + if (entry_node != nullptr) { + DLOG(INFO) << "add key " << key.ToString() << " to node cache. version " << gc_version_; + node_cache_.AddKeyEntryNode(gc_version_.load(std::memory_order_relaxed), entry_node); + } return true; } @@ -368,12 +406,13 @@ void Segment::FreeList(uint32_t ts_idx, ::openmldb::base::NodeIncrIdxCnt(ts_idx); ::openmldb::base::Node* tmp = node; idx_byte_size_.fetch_sub(GetRecordTsIdxSize(tmp->Height())); + DLOG(INFO) << "idx_byte_size_ " << idx_byte_size_ << " after sub " << GetRecordTsIdxSize(tmp->Height()); node = node->GetNextNoBarrier(0); - DEBUGLOG("delete key %lu with height %u", tmp->GetKey(), tmp->Height()); + VLOG(1) << "delete key " << tmp->GetKey() << " with height " << (unsigned int)tmp->Height(); if (tmp->GetValue()->dim_cnt_down > 1) { tmp->GetValue()->dim_cnt_down--; } else { - DEBUGLOG("delele data block for key %lu", tmp->GetKey()); + VLOG(1) << "delete data block for key " << tmp->GetKey(); statistics_info->record_byte_size += GetRecordSize(tmp->GetValue()->size); delete tmp->GetValue(); } @@ -387,12 +426,17 @@ void Segment::GcFreeList(StatisticsInfo* statistics_info) { return; } StatisticsInfo old = *statistics_info; + DLOG(INFO) << "cur " << old.DebugString(); uint64_t free_list_version = cur_version - FLAGS_gc_deleted_pk_version_delta; node_cache_.Free(free_list_version, statistics_info); + DLOG(INFO) << "after node cache free " << statistics_info->DebugString(); for (size_t idx = 0; idx < idx_cnt_vec_.size(); idx++) { idx_cnt_vec_[idx]->fetch_sub(statistics_info->GetIdxCnt(idx) - old.GetIdxCnt(idx), std::memory_order_relaxed); } + idx_byte_size_.fetch_sub(statistics_info->idx_byte_size - old.idx_byte_size); 
+ DLOG(INFO) << "idx_byte_size_ " << idx_byte_size_ << " after sub " + << statistics_info->idx_byte_size - old.idx_byte_size; } void Segment::ExecuteGc(const TTLSt& ttl_st, StatisticsInfo* statistics_info) { @@ -434,11 +478,16 @@ void Segment::ExecuteGc(const TTLSt& ttl_st, StatisticsInfo* statistics_info) { } } -void Segment::ExecuteGc(const std::map& ttl_st_map, StatisticsInfo* statistics_info) { +void Segment::ExecuteGc(const std::map& ttl_st_map, StatisticsInfo* statistics_info, + std::optional clustered_ts_id) { if (ttl_st_map.empty()) { return; } if (ts_cnt_ <= 1) { + if (clustered_ts_id.has_value() && ts_idx_map_.begin()->first == clustered_ts_id.value()) { + DLOG(INFO) << "skip normal gc in cidx"; + return; + } ExecuteGc(ttl_st_map.begin()->second, statistics_info); return; } @@ -454,7 +503,7 @@ void Segment::ExecuteGc(const std::map& ttl_st_map, StatisticsI if (!need_gc) { return; } - GcAllType(ttl_st_map, statistics_info); + GcAllType(ttl_st_map, statistics_info, clustered_ts_id); } void Segment::Gc4Head(uint64_t keep_cnt, StatisticsInfo* statistics_info) { @@ -485,11 +534,16 @@ void Segment::Gc4Head(uint64_t keep_cnt, StatisticsInfo* statistics_info) { idx_cnt_vec_[0]->fetch_sub(statistics_info->GetIdxCnt(0) - old, std::memory_order_relaxed); } -void Segment::GcAllType(const std::map& ttl_st_map, StatisticsInfo* statistics_info) { +void Segment::GcAllType(const std::map& ttl_st_map, StatisticsInfo* statistics_info, + std::optional clustered_ts_id) { uint64_t old = statistics_info->GetTotalCnt(); uint64_t consumed = ::baidu::common::timer::get_micros(); std::unique_ptr it(entries_->NewIterator()); it->SeekToFirst(); + for (auto [ts, ttl_st] : ttl_st_map) { + DLOG(INFO) << "ts " << ts << " ttl_st " << ttl_st.ToString() << " it will be current time - ttl?"; + } + while (it->Valid()) { KeyEntry** entry_arr = reinterpret_cast(it->GetValue()); Slice key = it->GetKey(); @@ -501,6 +555,11 @@ void Segment::GcAllType(const std::map& ttl_st_map, StatisticsI } auto 
pos = ts_idx_map_.find(kv.first); if (pos == ts_idx_map_.end() || pos->second >= ts_cnt_) { + LOG(WARNING) << ""; + continue; + } + if (clustered_ts_id.has_value() && kv.first == clustered_ts_id.value()) { + DLOG(INFO) << "skip normal gc in cidx"; continue; } KeyEntry* entry = entry_arr[pos->second]; @@ -592,12 +651,13 @@ void Segment::GcAllType(const std::map& ttl_st_map, StatisticsI } } if (entry_node != nullptr) { + DLOG(INFO) << "add key " << key.ToString() << " to node cache. version " << gc_version_; node_cache_.AddKeyEntryNode(gc_version_.load(std::memory_order_relaxed), entry_node); } } } - DEBUGLOG("[GcAll] segment gc consumed %lu, count %lu", (::baidu::common::timer::get_micros() - consumed) / 1000, - statistics_info->GetTotalCnt() - old); + DLOG(INFO) << "[GcAll] segment gc consumed " << (::baidu::common::timer::get_micros() - consumed) / 1000 + << "ms, count " << statistics_info->GetTotalCnt() - old; } void Segment::SplitList(KeyEntry* entry, uint64_t ts, ::openmldb::base::Node** node) { @@ -745,6 +805,7 @@ void Segment::Gc4TTLOrHead(const uint64_t time, const uint64_t keep_cnt, Statist } } if (entry_node != nullptr) { + DLOG(INFO) << "add key " << key.ToString() << " to node cache. 
version " << gc_version_; node_cache_.AddKeyEntryNode(gc_version_.load(std::memory_order_relaxed), entry_node); } uint64_t cur_idx_cnt = statistics_info->GetIdxCnt(0); diff --git a/src/storage/segment.h b/src/storage/segment.h index 511df69e5c4..01a76374889 100644 --- a/src/storage/segment.h +++ b/src/storage/segment.h @@ -46,6 +46,7 @@ class MemTableIterator : public TableIterator { void Seek(const uint64_t time) override; bool Valid() override; void Next() override; + // GetXXX will core if it_==nullptr, don't use it without valid openmldb::base::Slice GetValue() const override; uint64_t GetKey() const override; void SeekToFirst() override; @@ -68,7 +69,7 @@ class Segment { public: explicit Segment(uint8_t height); Segment(uint8_t height, const std::vector& ts_idx_vec); - ~Segment(); + virtual ~Segment(); // legacy interface called by memtable and ut void Put(const Slice& key, uint64_t time, const char* data, uint32_t size, bool put_if_absent = false, @@ -78,22 +79,25 @@ class Segment { void BulkLoadPut(unsigned int key_entry_id, const Slice& key, uint64_t time, DataBlock* row); // main put method - bool Put(const Slice& key, const std::map& ts_map, DataBlock* row, bool put_if_absent = false); + virtual bool Put(const Slice& key, const std::map& ts_map, DataBlock* row, + bool put_if_absent = false); bool Delete(const std::optional& idx, const Slice& key); - bool Delete(const std::optional& idx, const Slice& key, - uint64_t ts, const std::optional& end_ts); + bool Delete(const std::optional& idx, const Slice& key, uint64_t ts, + const std::optional& end_ts); void Release(StatisticsInfo* statistics_info); void ExecuteGc(const TTLSt& ttl_st, StatisticsInfo* statistics_info); - void ExecuteGc(const std::map& ttl_st_map, StatisticsInfo* statistics_info); + void ExecuteGc(const std::map& ttl_st_map, StatisticsInfo* statistics_info, + std::optional clustered_ts_id = std::nullopt); void Gc4TTL(const uint64_t time, StatisticsInfo* statistics_info); void Gc4Head(uint64_t 
keep_cnt, StatisticsInfo* statistics_info); void Gc4TTLAndHead(const uint64_t time, const uint64_t keep_cnt, StatisticsInfo* statistics_info); void Gc4TTLOrHead(const uint64_t time, const uint64_t keep_cnt, StatisticsInfo* statistics_info); - void GcAllType(const std::map& ttl_st_map, StatisticsInfo* statistics_info); + void GcAllType(const std::map& ttl_st_map, StatisticsInfo* statistics_info, + std::optional clustered_ts_id = std::nullopt); MemTableIterator* NewIterator(const Slice& key, Ticket& ticket, type::CompressType compress_type); // NOLINT MemTableIterator* NewIterator(const Slice& key, uint32_t idx, Ticket& ticket, // NOLINT @@ -141,17 +145,17 @@ class Segment { void ReleaseAndCount(const std::vector& id_vec, StatisticsInfo* statistics_info); - private: + protected: void FreeList(uint32_t ts_idx, ::openmldb::base::Node* node, StatisticsInfo* statistics_info); void SplitList(KeyEntry* entry, uint64_t ts, ::openmldb::base::Node** node); bool GetTsIdx(const std::optional& idx, uint32_t* ts_idx); bool ListContains(KeyEntry* entry, uint64_t time, DataBlock* row, bool check_all_time); - bool PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent = false, - bool check_all_time = false); + virtual bool PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent = false, + bool check_all_time = false); - private: + protected: KeyEntries* entries_; std::mutex mu_; std::atomic idx_byte_size_; @@ -159,6 +163,7 @@ class Segment { uint8_t key_entry_max_height_; uint32_t ts_cnt_; std::atomic gc_version_; + // std::map ts_idx_map_; std::vector>> idx_cnt_vec_; uint64_t ttl_offset_; diff --git a/src/storage/table.cc b/src/storage/table.cc index 7126430a9d5..ebb27bf73ef 100644 --- a/src/storage/table.cc +++ b/src/storage/table.cc @@ -207,8 +207,10 @@ bool Table::AddIndex(const ::openmldb::common::ColumnKey& column_key) { } col_vec.push_back(it->second); } + + common::IndexType index_type = column_key.has_type() ? 
column_key.type() : common::IndexType::kCovering; index_def = std::make_shared(column_key.index_name(), table_index_.GetMaxIndexId() + 1, - IndexStatus::kReady, ::openmldb::type::IndexType::kTimeSerise, col_vec); + IndexStatus::kReady, ::openmldb::type::IndexType::kTimeSerise, col_vec, index_type); if (!column_key.ts_name().empty()) { if (auto ts_iter = schema.find(column_key.ts_name()); ts_iter == schema.end()) { PDLOG(WARNING, "not found ts_name[%s]. tid %u pid %u", column_key.ts_name().c_str(), id_, pid_); diff --git a/src/storage/table_iterator_test.cc b/src/storage/table_iterator_test.cc index 3af20940266..47847498e0d 100644 --- a/src/storage/table_iterator_test.cc +++ b/src/storage/table_iterator_test.cc @@ -152,7 +152,8 @@ TEST_P(TableIteratorTest, latest) { dim->set_key(key); std::string value; ASSERT_EQ(0, codec.EncodeRow(row, &value)); - table->Put(0, value, request.dimensions()); + auto st = table->Put(0, value, request.dimensions()); + ASSERT_TRUE(st.ok()) << st.ToString(); } } ::hybridse::vm::WindowIterator* it = table->NewWindowIterator(0); @@ -216,7 +217,8 @@ TEST_P(TableIteratorTest, smoketest2) { dim->set_key(key); std::string value; ASSERT_EQ(0, codec.EncodeRow(row, &value)); - table->Put(0, value, request.dimensions()); + auto st = table->Put(0, value, request.dimensions()); + ASSERT_TRUE(st.ok()) << st.ToString(); } } ::hybridse::vm::WindowIterator* it = table->NewWindowIterator(0); @@ -383,7 +385,8 @@ TEST_P(TableIteratorTest, releaseKeyIterator) { dim->set_key(key); std::string value; ASSERT_EQ(0, codec.EncodeRow(row, &value)); - table->Put(0, value, request.dimensions()); + auto st = table->Put(0, value, request.dimensions()); + ASSERT_TRUE(st.ok()) << st.ToString(); } } @@ -429,7 +432,8 @@ TEST_P(TableIteratorTest, SeekNonExistent) { dim->set_key(key); std::string value; ASSERT_EQ(0, codec.EncodeRow(row, &value)); - table->Put(0, value, request.dimensions()); + auto st = table->Put(0, value, request.dimensions()); + ASSERT_TRUE(st.ok()) << 
st.ToString(); } } diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index 230b5c46a09..45322ca03d5 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -18,6 +18,7 @@ #include #include + #include #include #include "vm/sql_compiler.h" @@ -35,8 +36,6 @@ #include "absl/cleanup/cleanup.h" #include "absl/time/clock.h" #include "absl/time/time.h" -#include "boost/bind.hpp" -#include "boost/container/deque.hpp" #include "base/file_util.h" #include "base/glog_wrapper.h" #include "base/hash.h" @@ -45,6 +44,8 @@ #include "base/status.h" #include "base/strings.h" #include "base/sys_info.h" +#include "boost/bind.hpp" +#include "boost/container/deque.hpp" #include "brpc/controller.h" #include "butil/iobuf.h" #include "codec/codec.h" @@ -63,6 +64,7 @@ #include "schema/schema_adapter.h" #include "storage/binlog.h" #include "storage/disk_table_snapshot.h" +#include "storage/index_organized_table.h" #include "storage/segment.h" #include "storage/table.h" #include "tablet/file_sender.h" @@ -202,7 +204,7 @@ bool TabletImpl::Init(const std::string& zk_cluster, const std::string& zk_path, if (!zk_cluster.empty()) { zk_client_ = new ZkClient(zk_cluster, real_endpoint, FLAGS_zk_session_timeout, endpoint, zk_path, - FLAGS_zk_auth_schema, FLAGS_zk_cert); + FLAGS_zk_auth_schema, FLAGS_zk_cert); bool ok = zk_client_->Init(); if (!ok) { PDLOG(ERROR, "fail to init zookeeper with cluster %s", zk_cluster.c_str()); @@ -375,8 +377,8 @@ void TabletImpl::UpdateTTL(RpcController* ctrl, const ::openmldb::api::UpdateTTL base::SetResponseStatus(base::ReturnCode::kWriteDataFailed, "write meta data failed", response); return; } - PDLOG(INFO, "update table tid %u pid %u ttl meta to abs_ttl %lu lat_ttl %lu index_name %s", tid, pid, abs_ttl, lat_ttl, - index_name.c_str()); + PDLOG(INFO, "update table tid %u pid %u ttl meta to abs_ttl %lu lat_ttl %lu index_name %s", tid, pid, abs_ttl, + lat_ttl, index_name.c_str()); response->set_code(::openmldb::base::ReturnCode::kOk); 
response->set_msg("ok"); } @@ -465,7 +467,7 @@ int32_t TabletImpl::GetIndex(const ::openmldb::api::GetRequest* request, const : const std::map>& vers_schema, CombineIterator* it, std::string* value, uint64_t* ts) { if (it == nullptr || value == nullptr || ts == nullptr) { - PDLOG(WARNING, "invalid args"); + LOG(WARNING) << "invalid args"; return -1; } uint64_t st = request->ts(); @@ -473,10 +475,12 @@ int32_t TabletImpl::GetIndex(const ::openmldb::api::GetRequest* request, const : uint64_t et = request->et(); const openmldb::api::GetType& et_type = request->et_type(); if (st_type == ::openmldb::api::kSubKeyEq && et_type == ::openmldb::api::kSubKeyEq && st != et) { + LOG(WARNING) << "invalid args for st " << st << " not equal to et " << et; return -1; } ::openmldb::api::GetType real_et_type = et_type; ::openmldb::storage::TTLType ttl_type = it->GetTTLType(); + uint64_t expire_time = it->GetExpireTime(); if (ttl_type == ::openmldb::storage::TTLType::kAbsoluteTime || ttl_type == ::openmldb::storage::TTLType::kAbsOrLat) { @@ -485,22 +489,28 @@ int32_t TabletImpl::GetIndex(const ::openmldb::api::GetRequest* request, const : if (et < expire_time && et_type == ::openmldb::api::GetType::kSubKeyGt) { real_et_type = ::openmldb::api::GetType::kSubKeyGe; } + DLOG(INFO) << "expire time " << expire_time << ", after adjust: et " << et << " real_et_type " << real_et_type; + bool enable_project = false; openmldb::codec::RowProject row_project(vers_schema, request->projection()); if (request->projection().size() > 0) { bool ok = row_project.Init(); if (!ok) { - PDLOG(WARNING, "invalid project list"); + LOG(WARNING) << "invalid project list"; return -1; } enable_project = true; } + // it's ok when st < et(after adjust), we should return 0 rows cuz no valid data for this range + // but we have set the code -1, don't change the return code, accept it. 
if (st > 0 && st < et) { - DEBUGLOG("invalid args for st %lu less than et %lu or expire time %lu", st, et, expire_time); + DLOG(WARNING) << "invalid args for st " << st << " less than et " << et; return -1; } + DLOG(INFO) << "it valid " << it->Valid(); if (it->Valid()) { *ts = it->GetTs(); + DLOG(INFO) << "check " << *ts << " " << st << " " << et << " " << st_type << " " << real_et_type; if (st_type == ::openmldb::api::GetType::kSubKeyEq && st > 0 && *ts != st) { return 1; } @@ -514,7 +524,7 @@ int32_t TabletImpl::GetIndex(const ::openmldb::api::GetRequest* request, const : const int8_t* row_ptr = reinterpret_cast(data.data()); bool ok = row_project.Project(row_ptr, data.size(), &ptr, &size); if (!ok) { - PDLOG(WARNING, "fail to make a projection"); + LOG(WARNING) << "fail to make a projection"; return -4; } value->assign(reinterpret_cast(ptr), size); @@ -544,7 +554,7 @@ int32_t TabletImpl::GetIndex(const ::openmldb::api::GetRequest* request, const : break; default: - PDLOG(WARNING, "invalid et type %s", ::openmldb::api::GetType_Name(et_type).c_str()); + LOG(WARNING) << "invalid et type " << ::openmldb::api::GetType_Name(et_type).c_str(); return -2; } if (jump_out) { @@ -557,7 +567,7 @@ int32_t TabletImpl::GetIndex(const ::openmldb::api::GetRequest* request, const : const int8_t* row_ptr = reinterpret_cast(data.data()); bool ok = row_project.Project(row_ptr, data.size(), &ptr, &size); if (!ok) { - PDLOG(WARNING, "fail to make a projection"); + LOG(WARNING) << "fail to make a projection"; return -4; } value->assign(reinterpret_cast(ptr), size); @@ -671,6 +681,7 @@ void TabletImpl::Get(RpcController* controller, const ::openmldb::api::GetReques int32_t code = GetIndex(request, *table_meta, vers_schema, &combine_it, value, &ts); response->set_ts(ts); response->set_code(code); + DLOG(WARNING) << "get key " << request->key() << " ts " << ts << " code " << code; uint64_t end_time = ::baidu::common::timer::get_micros(); if (start_time + FLAGS_query_slow_log_threshold < 
end_time) { std::string index_name; @@ -750,10 +761,43 @@ void TabletImpl::Put(RpcController* controller, const ::openmldb::api::PutReques response->set_msg("invalid dimension parameter"); return; } - DLOG(INFO) << "put data to tid " << tid << " pid " << pid << " with key " << request->dimensions(0).key(); - // 1. normal put: ok, invalid data - // 2. put if absent: ok, exists but ignore, invalid data - st = table->Put(entry.ts(), entry.value(), entry.dimensions(), request->put_if_absent()); + if (request->check_exists()) { + // table should be iot + auto iot = std::dynamic_pointer_cast(table); + if (!iot) { + response->set_code(::openmldb::base::ReturnCode::kTableMetaIsIllegal); + response->set_msg("table type is not iot"); + return; + } + DLOG(INFO) << "check data exists in tid " << tid << " pid " << pid << " with key " + << entry.dimensions(0).key() << " ts " << entry.ts(); + // ts is ts value when check exists + st = iot->CheckDataExists(entry.ts(), entry.dimensions()); + } else { + DLOG(INFO) << "put data to tid " << tid << " pid " << pid << " with key " << request->dimensions(0).key(); + // 1. normal put: ok, invalid data + // 2. 
put if absent: ok, exists but ignore, invalid data + st = table->Put(entry.ts(), entry.value(), entry.dimensions(), request->put_if_absent()); + } + } + // when check exists, we won't do log + if (request->check_exists()) { + DLOG_ASSERT(request->check_exists()) << "check_exists should be true"; + DLOG_ASSERT(!request->put_if_absent()) << "put_if_absent should be false"; + DLOG(INFO) << "result " << st.ToString(); + // return ok if exists + if (absl::IsAlreadyExists(st)) { + response->set_code(base::ReturnCode::kOk); + response->set_msg("exists"); + } else if (absl::IsNotFound(st)) { + response->set_code(base::ReturnCode::kKeyNotFound); + response->set_msg(st.ToString()); + } else { + // other errors + response->set_code(base::ReturnCode::kError); + response->set_msg(st.ToString()); + } + return; } if (!st.ok()) { @@ -1333,7 +1377,7 @@ void TabletImpl::Traverse(RpcController* controller, const ::openmldb::api::Trav } base::Status TabletImpl::CheckTable(uint32_t tid, uint32_t pid, bool check_leader, - const std::shared_ptr& table) { + const std::shared_ptr
& table) { if (!table) { PDLOG(WARNING, "table does not exist. tid %u, pid %u", tid, pid); return {base::ReturnCode::kTableIsNotExist, "table does not exist"}; @@ -1350,15 +1394,16 @@ base::Status TabletImpl::CheckTable(uint32_t tid, uint32_t pid, bool check_leade } base::Status TabletImpl::DeleteAllIndex(const std::shared_ptr& table, - const std::shared_ptr& cur_index, - const std::string& key, - std::optional start_ts, - std::optional end_ts, + const std::shared_ptr& cur_index, const std::string& key, + std::optional start_ts, std::optional end_ts, bool skip_cur_ts_col, const std::shared_ptr& client_manager, uint32_t partition_num) { storage::Ticket ticket; std::unique_ptr iter(table->NewIterator(cur_index->GetId(), key, ticket)); + DLOG(INFO) << "delete all index in " << table->GetId() << "." << cur_index->GetId() << ", key " << key + << ", start_ts " << (start_ts.has_value() ? std::to_string(start_ts.value()) : "-1") << ", end_ts " + << (end_ts.has_value() ? std::to_string(end_ts.value()) : "-1"); if (start_ts.has_value()) { iter->Seek(start_ts.value()); } else { @@ -1366,7 +1411,7 @@ base::Status TabletImpl::DeleteAllIndex(const std::shared_ptr& t } auto indexs = table->GetAllIndex(); while (iter->Valid()) { - DEBUGLOG("cur ts %lu cur index pos %u", iter->GetKey(), cur_index->GetId()); + DLOG(INFO) << "cur ts " << iter->GetKey(); if (end_ts.has_value() && iter->GetKey() <= end_ts.value()) { break; } @@ -1450,14 +1495,13 @@ base::Status TabletImpl::DeleteAllIndex(const std::shared_ptr& t if (client == nullptr) { return {base::ReturnCode::kDeleteFailed, absl::StrCat("client is nullptr, pid ", cur_pid)}; } - DEBUGLOG("delete idx %u pid %u pk %s ts %lu end_ts %lu", - option.idx.value(), cur_pid, option.key.c_str(), option.start_ts.value(), option.end_ts.value()); std::string msg; // do not delete other index data option.enable_decode_value = false; + DLOG(INFO) << "pid " << cur_pid << " delete key " << option.DebugString(); if (auto status = 
client->Delete(table->GetId(), cur_pid, option, FLAGS_request_timeout_ms); !status.OK()) { return {base::ReturnCode::kDeleteFailed, - absl::StrCat("delete failed. key ", option.key, " pid ", cur_pid, " msg: ", status.GetMsg())}; + absl::StrCat("delete failed. key ", option.key, " pid ", cur_pid, " msg: ", status.GetMsg())}; } } @@ -1478,7 +1522,7 @@ void TabletImpl::Delete(RpcController* controller, const ::openmldb::api::Delete } auto table = GetTable(tid, pid); if (auto status = CheckTable(tid, pid, true, table); !status.OK()) { - SetResponseStatus(status, response); + SET_RESP_AND_WARN(response, status.GetCode(), status.GetMsg()); return; } auto replicator = GetReplicator(tid, pid); @@ -1547,13 +1591,14 @@ void TabletImpl::Delete(RpcController* controller, const ::openmldb::api::Delete } } } + DLOG(INFO) << tid << "." << pid << ": delete request " << request->ShortDebugString() << ", delete others " + << delete_others; auto aggrs = GetAggregators(tid, pid); if (!aggrs && !delete_others) { if (table->Delete(entry)) { - DEBUGLOG("delete ok. tid %u, pid %u, key %s", tid, pid, request->key().c_str()); + DLOG(INFO) << tid << "." << pid << ": delete ok, key " << request->key(); } else { - response->set_code(::openmldb::base::ReturnCode::kDeleteFailed); - response->set_msg("delete failed"); + SET_RESP_AND_WARN(response, base::ReturnCode::kDeleteFailed, "delete failed"); return; } } else { @@ -1585,36 +1630,37 @@ void TabletImpl::Delete(RpcController* controller, const ::openmldb::api::Delete } uint32_t pid_num = tablet_table_handler->GetPartitionNum(); auto table_client_manager = tablet_table_handler->GetTableClientManager(); + DLOG(INFO) << "delete from table & aggr " << tid << "." 
<< pid; if (entry.dimensions_size() > 0) { const auto& dimension = entry.dimensions(0); uint32_t idx = dimension.idx(); auto index_def = table->GetIndex(idx); const auto& key = dimension.key(); if (delete_others) { - auto status = DeleteAllIndex(table, index_def, key, start_ts, end_ts, false, - table_client_manager, pid_num); + auto status = + DeleteAllIndex(table, index_def, key, start_ts, end_ts, false, table_client_manager, pid_num); if (!status.OK()) { SET_RESP_AND_WARN(response, status.GetCode(), status.GetMsg()); return; } } if (!table->Delete(idx, key, start_ts, end_ts)) { - response->set_code(::openmldb::base::ReturnCode::kDeleteFailed); - response->set_msg("delete failed"); + SET_RESP_AND_WARN(response, base::ReturnCode::kDeleteFailed, "delete from partition failed"); return; } auto aggr = get_aggregator(aggrs, idx); if (aggr) { if (!aggr->Delete(key, start_ts, end_ts)) { - PDLOG(WARNING, "delete from aggr failed. base table: tid[%u] pid[%u] index[%u] key[%s]. " - "aggr table: tid[%u]", + PDLOG(WARNING, + "delete from aggr failed. base table: tid[%u] pid[%u] index[%u] key[%s]. " + "aggr table: tid[%u]", tid, pid, idx, key.c_str(), aggr->GetAggrTid()); response->set_code(::openmldb::base::ReturnCode::kDeleteFailed); response->set_msg("delete from associated pre-aggr table failed"); return; } } - DEBUGLOG("delete ok. tid %u, pid %u, key %s", tid, pid, key.c_str()); + DLOG(INFO) << tid << "." 
<< pid << ": table & agg delete ok, key " << key; } else { bool is_first_hit_index = true; for (const auto& index_def : table->GetAllIndex()) { @@ -1630,8 +1676,8 @@ void TabletImpl::Delete(RpcController* controller, const ::openmldb::api::Delete while (iter->Valid()) { auto pk = iter->GetPK(); if (delete_others && is_first_hit_index) { - auto status = DeleteAllIndex(table, index_def, pk, start_ts, end_ts, true, - table_client_manager, pid_num); + auto status = + DeleteAllIndex(table, index_def, pk, start_ts, end_ts, true, table_client_manager, pid_num); if (!status.OK()) { SET_RESP_AND_WARN(response, status.GetCode(), status.GetMsg()); return; @@ -1639,15 +1685,16 @@ void TabletImpl::Delete(RpcController* controller, const ::openmldb::api::Delete } iter->NextPK(); if (!table->Delete(idx, pk, start_ts, end_ts)) { - response->set_code(::openmldb::base::ReturnCode::kDeleteFailed); - response->set_msg("delete failed"); + SET_RESP_AND_WARN(response, base::ReturnCode::kDeleteFailed, "delete failed"); return; } auto aggr = get_aggregator(aggrs, idx); if (aggr) { if (!aggr->Delete(pk, start_ts, end_ts)) { - PDLOG(WARNING, "delete from aggr failed. base table: tid[%u] pid[%u] index[%u] key[%s]. " - "aggr table: tid[%u]", tid, pid, idx, pk.c_str(), aggr->GetAggrTid()); + PDLOG(WARNING, + "delete from aggr failed. base table: tid[%u] pid[%u] index[%u] key[%s]. " + "aggr table: tid[%u]", + tid, pid, idx, pk.c_str(), aggr->GetAggrTid()); response->set_code(::openmldb::base::ReturnCode::kDeleteFailed); response->set_msg("delete from associated pre-aggr table failed"); return; @@ -1656,11 +1703,11 @@ void TabletImpl::Delete(RpcController* controller, const ::openmldb::api::Delete } is_first_hit_index = false; } + DLOG(INFO) << tid << "." 
<< pid << ": table & agg delete ok when no entry dim."; } } response->set_code(::openmldb::base::ReturnCode::kOk); response->set_msg("ok"); - replicator->AppendEntry(entry); if (FLAGS_binlog_notify_on_put) { replicator->Notify(); @@ -2563,7 +2610,7 @@ void TabletImpl::SetExpire(RpcController* controller, const ::openmldb::api::Set } void TabletImpl::MakeSnapshotInternal(uint32_t tid, uint32_t pid, uint64_t end_offset, - std::shared_ptr<::openmldb::api::TaskInfo> task, bool is_force) { + std::shared_ptr<::openmldb::api::TaskInfo> task, bool is_force) { PDLOG(INFO, "MakeSnapshotInternal begin, tid[%u] pid[%u]", tid, pid); std::shared_ptr
table; std::shared_ptr snapshot; @@ -3115,8 +3162,8 @@ void TabletImpl::LoadTable(RpcController* controller, const ::openmldb::api::Loa std::string db_path = GetDBPath(root_path, tid, pid); if (!::openmldb::base::IsExists(db_path)) { - PDLOG(WARNING, "table db path does not exist, but still load. tid %u, pid %u, path %s", - tid, pid, db_path.c_str()); + PDLOG(WARNING, "table db path does not exist, but still load. tid %u, pid %u, path %s", tid, pid, + db_path.c_str()); } std::shared_ptr
table = GetTable(tid, pid); @@ -3539,7 +3586,7 @@ void TabletImpl::CreateTable(RpcController* controller, const ::openmldb::api::C } void TabletImpl::TruncateTable(RpcController* controller, const ::openmldb::api::TruncateTableRequest* request, - ::openmldb::api::TruncateTableResponse* response, Closure* done) { + ::openmldb::api::TruncateTableResponse* response, Closure* done) { brpc::ClosureGuard done_guard(done); uint32_t tid = request->tid(); uint32_t pid = request->pid(); @@ -3552,8 +3599,8 @@ void TabletImpl::TruncateTable(RpcController* controller, const ::openmldb::api: for (const auto& aggr : *aggrs) { auto agg_table = aggr->GetAggTable(); if (!agg_table) { - PDLOG(WARNING, "aggrate table does not exist. tid[%u] pid[%u] index pos[%u]", - tid, pid, aggr->GetIndexPos()); + PDLOG(WARNING, "aggrate table does not exist. tid[%u] pid[%u] index pos[%u]", tid, pid, + aggr->GetIndexPos()); response->set_code(::openmldb::base::ReturnCode::kTableIsNotExist); response->set_msg("aggrate table does not exist"); return; @@ -3561,13 +3608,13 @@ void TabletImpl::TruncateTable(RpcController* controller, const ::openmldb::api: uint32_t agg_tid = agg_table->GetId(); uint32_t agg_pid = agg_table->GetPid(); if (auto status = TruncateTableInternal(agg_tid, agg_pid); !status.OK()) { - PDLOG(WARNING, "truncate aggrate table failed. tid[%u] pid[%u] index pos[%u]", - agg_tid, agg_pid, aggr->GetIndexPos()); + PDLOG(WARNING, "truncate aggrate table failed. tid[%u] pid[%u] index pos[%u]", agg_tid, agg_pid, + aggr->GetIndexPos()); base::SetResponseStatus(status, response); return; } - PDLOG(INFO, "truncate aggrate table success. tid[%u] pid[%u] index pos[%u]", - agg_tid, agg_pid, aggr->GetIndexPos()); + PDLOG(INFO, "truncate aggrate table success. 
tid[%u] pid[%u] index pos[%u]", agg_tid, agg_pid, + aggr->GetIndexPos()); } } response->set_code(::openmldb::base::ReturnCode::kOk); @@ -3620,8 +3667,8 @@ base::Status TabletImpl::TruncateTableInternal(uint32_t tid, uint32_t pid) { if (catalog_->AddTable(*table_meta, new_table)) { LOG(INFO) << "add table " << table_meta->name() << " to catalog with db " << table_meta->db(); } else { - LOG(WARNING) << "fail to add table " << table_meta->name() - << " to catalog with db " << table_meta->db(); + LOG(WARNING) << "fail to add table " << table_meta->name() << " to catalog with db " + << table_meta->db(); return {::openmldb::base::ReturnCode::kCatalogUpdateFailed, "fail to update catalog"}; } } @@ -3659,7 +3706,7 @@ void TabletImpl::ExecuteGc(RpcController* controller, const ::openmldb::api::Exe gc_pool_.AddTask(boost::bind(&TabletImpl::GcTable, this, tid, pid, true)); response->set_code(::openmldb::base::ReturnCode::kOk); response->set_msg("ok"); - PDLOG(INFO, "ExecuteGc. tid %u pid %u", tid, pid); + PDLOG(INFO, "ExecuteGc add task. 
tid %u pid %u", tid, pid); } void TabletImpl::GetTableFollower(RpcController* controller, const ::openmldb::api::GetTableFollowerRequest* request, @@ -4025,6 +4072,25 @@ int TabletImpl::UpdateTableMeta(const std::string& path, ::openmldb::api::TableM return UpdateTableMeta(path, table_meta, false); } +bool IsIOT(const ::openmldb::api::TableMeta* table_meta) { + auto cks = table_meta->column_key(); + if (cks.empty()) { + LOG(WARNING) << "no index in meta"; + return false; + } + if (cks[0].has_type() && cks[0].type() == common::IndexType::kClustered) { + // check other indexes + for (int i = 1; i < cks.size(); i++) { + if (cks[i].has_type() && cks[i].type() == common::IndexType::kClustered) { + LOG(WARNING) << "should be only one clustered index"; + return false; + } + } + return true; + } + return false; +} + int TabletImpl::CreateTableInternal(const ::openmldb::api::TableMeta* table_meta, std::string& msg) { uint32_t tid = table_meta->tid(); uint32_t pid = table_meta->pid(); @@ -4054,7 +4120,12 @@ int TabletImpl::CreateTableInternal(const ::openmldb::api::TableMeta* table_meta } std::string table_db_path = GetDBPath(db_root_path, tid, pid); if (table_meta->storage_mode() == openmldb::common::kMemory) { - table = std::make_shared(*table_meta); + if (IsIOT(table_meta)) { + LOG(INFO) << "create iot table " << tid << "." << pid; + table = std::make_shared(*table_meta, catalog_); + } else { + table = std::make_shared(*table_meta); + } } else { table = std::make_shared(*table_meta, table_db_path); } @@ -4292,7 +4363,16 @@ void TabletImpl::GcTable(uint32_t tid, uint32_t pid, bool execute_once) { std::shared_ptr
table = GetTable(tid, pid); if (table) { int32_t gc_interval = table->GetStorageMode() == common::kMemory ? FLAGS_gc_interval : FLAGS_disk_gc_interval; - table->SchedGc(); + if (auto iot = std::dynamic_pointer_cast(table); iot) { + sdk::SQLRouterOptions options; + options.zk_cluster = zk_cluster_; + options.zk_path = zk_path_; + auto router = sdk::NewClusterSQLRouter(options); + iot->SchedGCByDelete(router); // add a lock to avoid gc one table in the same time + } else { + table->SchedGc(); + } + if (!execute_once) { gc_pool_.DelayTask(gc_interval * 60 * 1000, boost::bind(&TabletImpl::GcTable, this, tid, pid, false)); } @@ -5228,12 +5308,12 @@ void TabletImpl::ExtractIndexData(RpcController* controller, const ::openmldb::a index_vec.push_back(cur_column_key); } if (IsClusterMode()) { - task_pool_.AddTask(boost::bind(&TabletImpl::ExtractIndexDataInternal, this, table, snapshot, - index_vec, request->partition_num(), request->offset(), request->dump_data(), + task_pool_.AddTask(boost::bind(&TabletImpl::ExtractIndexDataInternal, this, table, snapshot, index_vec, + request->partition_num(), request->offset(), request->dump_data(), task_ptr)); } else { - ExtractIndexDataInternal(table, snapshot, index_vec, request->partition_num(), request->offset(), - false, nullptr); + ExtractIndexDataInternal(table, snapshot, index_vec, request->partition_num(), request->offset(), false, + nullptr); } base::SetResponseOK(response); return; @@ -5850,9 +5930,10 @@ bool TabletImpl::CreateAggregatorInternal(const ::openmldb::api::CreateAggregato PDLOG(WARNING, "base table does not exist. 
tid %u, pid %u", base_meta.tid(), base_meta.pid()); return false; } - auto aggregator = ::openmldb::storage::CreateAggregator(base_meta, base_table, - *aggr_table->GetTableMeta(), aggr_table, aggr_replicator, request->index_pos(), request->aggr_col(), - request->aggr_func(), request->order_by_col(), request->bucket_size(), request->filter_col()); + auto aggregator = ::openmldb::storage::CreateAggregator( + base_meta, base_table, *aggr_table->GetTableMeta(), aggr_table, aggr_replicator, request->index_pos(), + request->aggr_col(), request->aggr_func(), request->order_by_col(), request->bucket_size(), + request->filter_col()); if (!aggregator) { msg.assign("create aggregator failed"); return false; @@ -5925,10 +6006,11 @@ TabletImpl::GetSystemTableIterator() { } auto schema = std::make_unique<::openmldb::codec::Schema>(); - + if (openmldb::schema::SchemaAdapter::ConvertSchema(*tablet_table_handler->GetSchema(), schema.get())) { std::map> tablet_clients = {{0, client}}; - return {{std::make_unique(tablet_table_handler->GetTid(), nullptr, tablet_clients), + return { + {std::make_unique(tablet_table_handler->GetTid(), nullptr, tablet_clients), std::move(schema)}}; } else { return std::nullopt; diff --git a/src/tablet/tablet_impl_test.cc b/src/tablet/tablet_impl_test.cc index 985c59e51d3..3df6a8e3553 100644 --- a/src/tablet/tablet_impl_test.cc +++ b/src/tablet/tablet_impl_test.cc @@ -6249,7 +6249,8 @@ TEST_F(TabletImplTest, DeleteRange) { ::openmldb::common::ColumnDesc* column_desc2 = table_meta->add_column_desc(); column_desc2->set_name("mcc"); column_desc2->set_data_type(::openmldb::type::kString); - SchemaCodec::SetIndex(table_meta->add_column_key(), "card", "card", "", ::openmldb::type::kAbsoluteTime, 120, 0); + // insert time ttl and 120 min, so data won't be gc by ttl + SchemaCodec::SetIndex(table_meta->add_column_key(), "card_idx", "card", "", ::openmldb::type::kAbsoluteTime, 120, 0); ::openmldb::api::CreateTableResponse response; tablet.CreateTable(NULL, 
&request, &response, &closure); @@ -6293,16 +6294,19 @@ TEST_F(TabletImplTest, DeleteRange) { delete_request.set_pid(1); delete_request.set_end_ts(1); tablet.Delete(NULL, &delete_request, &gen_response, &closure); - ASSERT_EQ(0, gen_response.code()); + ASSERT_EQ(0, gen_response.code()) << gen_response.ShortDebugString(); ::openmldb::api::ExecuteGcRequest e_request; e_request.set_tid(id); e_request.set_pid(1); + // async task, need to wait + // segment: entries -> node cache tablet.ExecuteGc(NULL, &e_request, &gen_response, &closure); + ASSERT_EQ(0, gen_response.code()) << gen_response.ShortDebugString(); sleep(2); + assert_status(100, 3400, 5786); // before node cache gc, status will be the same + // gc node cache tablet.ExecuteGc(NULL, &e_request, &gen_response, &closure); - sleep(2); - assert_status(0, 0, 1626); - tablet.ExecuteGc(NULL, &e_request, &gen_response, &closure); + ASSERT_EQ(0, gen_response.code()) << gen_response.ShortDebugString(); sleep(2); assert_status(0, 0, 0); } diff --git a/steps/test_python.sh b/steps/test_python.sh index 8c366f77b0c..3e3588b0db7 100644 --- a/steps/test_python.sh +++ b/steps/test_python.sh @@ -42,8 +42,8 @@ python3 -m pip install "${whl_name_sdk}[test]" cd "${ROOT_DIR}"/python/openmldb_tool/dist/ whl_name_tool=$(ls openmldb*.whl) echo "whl_name_tool:${whl_name_tool}" -# pip 23.1.2 just needs to install test(rpc is required by test) -python3 -m pip install "${whl_name_tool}[rpc,test]" +# pip 23.1.2 just needs to install test +python3 -m pip install "${whl_name_tool}[test]" python3 -m pip install pytest-cov diff --git a/tools/tool.py b/tools/tool.py index b95a6246fc5..4f92f2a4098 100644 --- a/tools/tool.py +++ b/tools/tool.py @@ -219,7 +219,7 @@ def GetTableInfoHTTP(self, database, table_name = ''): ns = self.endpoint_map[self.ns_leader] conn = httplib.HTTPConnection(ns) param = {"db": database, "name": table_name} - headers = {"Content-type": "application/json"} + headers = {"Content-type": "application/json", "Authorization": 
"foo"} conn.request("POST", "/NameServer/ShowTable", json.dumps(param), headers) response = conn.getresponse() if response.status != 200: @@ -233,13 +233,15 @@ def GetTableInfoHTTP(self, database, table_name = ''): def ParseTableInfo(self, table_info): result = {} + if not table_info: + return Status(-1, "table info is empty"), None for record in table_info: is_leader = True if record[4] == "leader" else False is_alive = True if record[5] == "yes" else False partition = Partition(record[0], record[1], record[2], record[3], is_leader, is_alive, record[6]) result.setdefault(record[2], []) result[record[2]].append(partition) - return result + return Status(), result def ParseTableInfoJson(self, table_info): """parse one table's partition info from json""" @@ -260,8 +262,7 @@ def GetTablePartition(self, database, table_name): status, result = self.GetTableInfo(database, table_name) if not status.OK: return status, None - partition_dict = self.ParseTableInfo(result) - return Status(), partition_dict + return self.ParseTableInfo(result) def GetAllTable(self, database): status, result = self.GetTableInfo(database) @@ -323,7 +324,7 @@ def LoadTableHTTP(self, endpoint, name, tid, pid, storage): # ttl won't effect, set to 0, and seg cnt is always 8 # and no matter if leader param = {"table_meta": {"name": name, "tid": tid, "pid": pid, "ttl":0, "seg_cnt":8, "storage_mode": storage}} - headers = {"Content-type": "application/json"} + headers = {"Content-type": "application/json", "Authorization": "foo"} conn.request("POST", "/TabletServer/LoadTable", json.dumps(param), headers) response = conn.getresponse() if response.status != 200: