From 9d58b46436d91dad42cb67f1bfc72765ca80a278 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Tue, 21 May 2024 08:47:57 +0000 Subject: [PATCH 1/5] feat(parser): simple ANSI SQL rewriter --- hybridse/src/CMakeLists.txt | 1 + hybridse/src/rewriter/ast_rewriter.cc | 279 +++++++++++++++++++++ hybridse/src/rewriter/ast_rewriter.h | 32 +++ hybridse/src/rewriter/ast_rewriter_test.cc | 67 +++++ 4 files changed, 379 insertions(+) create mode 100644 hybridse/src/rewriter/ast_rewriter.cc create mode 100644 hybridse/src/rewriter/ast_rewriter.h create mode 100644 hybridse/src/rewriter/ast_rewriter_test.cc diff --git a/hybridse/src/CMakeLists.txt b/hybridse/src/CMakeLists.txt index 80c5cc2a5a3..4f25d87ab70 100644 --- a/hybridse/src/CMakeLists.txt +++ b/hybridse/src/CMakeLists.txt @@ -48,6 +48,7 @@ hybridse_add_src_and_tests(vm) hybridse_add_src_and_tests(codec) hybridse_add_src_and_tests(case) hybridse_add_src_and_tests(passes) +hybridse_add_src_and_tests(rewriter) get_property(SRC_FILE_LIST_STR GLOBAL PROPERTY PROP_SRC_FILE_LIST) string(REPLACE " " ";" SRC_FILE_LIST ${SRC_FILE_LIST_STR}) diff --git a/hybridse/src/rewriter/ast_rewriter.cc b/hybridse/src/rewriter/ast_rewriter.cc new file mode 100644 index 00000000000..f0b77c75bde --- /dev/null +++ b/hybridse/src/rewriter/ast_rewriter.cc @@ -0,0 +1,279 @@ +/** + * Copyright (c) 2024 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "rewriter/ast_rewriter.h" + +#include + +#include "plan/plan_api.h" +#include "zetasql/parser/parse_tree_manual.h" +#include "zetasql/parser/parser.h" +#include "zetasql/parser/unparser.h" + +namespace hybridse { +namespace rewriter { + +// unparser that make some rewrites so outputed SQL is +// compatible with ANSI SQL as much as can +class ANSISQLRewriteUnparser : public zetasql::parser::Unparser { + public: + explicit ANSISQLRewriteUnparser(std::string* unparsed) : zetasql::parser::Unparser(unparsed) {} + ~ANSISQLRewriteUnparser() override {} + ANSISQLRewriteUnparser(const ANSISQLRewriteUnparser&) = delete; + ANSISQLRewriteUnparser& operator=(const ANSISQLRewriteUnparser&) = delete; + + void visitASTSelect(const zetasql::ASTSelect* node, void* data) override { + while (true) { + absl::string_view filter_col; + + // 1. filter condition is 'col = 1' + if (node->where_clause() != nullptr && + node->where_clause()->expression()->node_kind() == zetasql::AST_BINARY_EXPRESSION) { + auto expr = node->where_clause()->expression()->GetAsOrNull(); + if (expr && expr->op() == zetasql::ASTBinaryExpression::Op::EQ && !expr->is_not()) { + { + auto lval = expr->lhs()->GetAsOrNull(); + auto rval = expr->rhs()->GetAsOrNull(); + if (lval && rval && lval->image() == "1") { + // TODO(someone): + // 1. consider lval->iamge() as '1L' + // 2. consider rval as . + filter_col = rval->last_name()->GetAsStringView(); + } + } + if (filter_col.empty()) { + auto lval = expr->rhs()->GetAsOrNull(); + auto rval = expr->lhs()->GetAsOrNull(); + if (lval && rval && lval->image() == "1") { + // TODO(someone): + // 1. consider lval->iamge() as '1L' + // 2. consider rval as . + filter_col = rval->last_name()->GetAsStringView(); + } + } + } + } + + // 2. FROM a subquery: SELECT ... t1 LEFT JOIN t2 WINDOW + const zetasql::ASTPathExpression* join_lhs_key = nullptr; + const zetasql::ASTPathExpression* join_rhs_key = nullptr; + if (node->from_clause() == nullptr) { + break; + } + auto sub = node->from_clause()->table_expression()->GetAsOrNull(); + if (!sub) { + break; + } + auto subquery = sub->subquery(); + if (subquery->with_clause() != nullptr || subquery->order_by() != nullptr || + subquery->limit_offset() != nullptr) { + break; + } + + auto select = subquery->query_expr()->GetAsOrNull(); + // select have window + if (select->window_clause() == nullptr || select->from_clause() == nullptr) { + break; + } + + // 3. CHECK FROM CLAUSE: must 't1 LEFT JOIN t2 on t1.key = t2.key' + auto join = select->from_clause()->table_expression()->GetAsOrNull(); + if (join == nullptr || join->join_type() != zetasql::ASTJoin::LEFT || join->on_clause() == nullptr) { + break; + } + auto on_expr = join->on_clause()->expression()->GetAsOrNull(); + if (on_expr == nullptr || on_expr->is_not() || on_expr->op() != zetasql::ASTBinaryExpression::EQ) { + break; + } + + // still might null + join_lhs_key = on_expr->lhs()->GetAsOrNull(); + join_rhs_key = on_expr->rhs()->GetAsOrNull(); + if (join_lhs_key == nullptr || join_rhs_key == nullptr) { + break; + } + + // 3. CHECK row_id is row_number() over w FROM select_list + bool found = false; + absl::string_view window_name; + for (auto col : select->select_list()->columns()) { + if (col->alias() && col->alias()->GetAsStringView() == filter_col) { + auto agg_func = col->expression()->GetAsOrNull(); + if (!agg_func || !agg_func->function()) { + break; + } + + auto w = agg_func->window_spec(); + if (!w || w->base_window_name() == nullptr) { + break; + } + window_name = w->base_window_name()->GetAsStringView(); + + auto ph = agg_func->function()->function(); + if (ph->num_names() == 1 && + absl::AsciiStrToLower(ph->first_name()->GetAsStringView()) == "row_number") { + found = true; + } + } + } + if (!found || window_name.empty()) { + break; + } + + // 4. CHECK WINDOW CLAUSE + { + if (select->window_clause()->windows().size() != 1) { + // targeting single window only + break; + } + auto win = select->window_clause()->windows().front(); + if (win->name()->GetAsStringView() != window_name) { + break; + } + auto spec = win->window_spec(); + if (spec->window_frame() != nullptr || spec->partition_by() == nullptr || spec->order_by() == nullptr) { + // TODO(someone): allow unbounded window frame + break; + } + + // PARTITION BY contains join_lhs_key + // ORDER BY is join_rhs_key + bool partition_meet = false; + for (auto expr : spec->partition_by()->partitioning_expressions()) { + auto e = expr->GetAsOrNull(); + if (e) { + if (e->last_name()->GetAsStringView() == join_lhs_key->last_name()->GetAsStringView()) { + partition_meet = true; + } + } + } + + if (!partition_meet) { + break; + } + + if (spec->order_by()->ordering_expressions().size() != 1) { + break; + } + + if (spec->order_by()->ordering_expressions().front()->ordering_spec() != + zetasql::ASTOrderingExpression::DESC) { + break; + } + + auto e = spec->order_by() + ->ordering_expressions() + .front() + ->expression() + ->GetAsOrNull(); + if (!e) { + break; + } + + // rewrite + { + PrintOpenParenIfNeeded(node); + println(); + print("SELECT"); + if (node->hint() != nullptr) { + node->hint()->Accept(this, data); + } + if (node->anonymization_options() != nullptr) { + print("WITH ANONYMIZATION OPTIONS"); + node->anonymization_options()->Accept(this, data); + } + if (node->distinct()) { + print("DISTINCT"); + } + + // Visit all children except hint() and anonymization_options, which we + // processed above. We can't just use visitASTChildren(node, data) because + // we need to insert the DISTINCT modifier after the hint and anonymization + // nodes and before everything else. + for (int i = 0; i < node->num_children(); ++i) { + const zetasql::ASTNode* child = node->child(i); + if (child == node->from_clause()) { + // this from subquery will simplified to join + println(); + print("FROM"); + visitASTJoinRewrited(join, e, data); + } else if (child != node->hint() && child != node->anonymization_options() && + child != node->where_clause()) { + child->Accept(this, data); + } + } + + println(); + PrintCloseParenIfNeeded(node); + return; + } + } + + break; + } + + zetasql::parser::Unparser::visitASTSelect(node, data); + } + + void visitASTJoinRewrited(const zetasql::ASTJoin* node, const zetasql::ASTPathExpression* order, void* data) { + node->child(0)->Accept(this, data); + + if (node->join_type() == zetasql::ASTJoin::COMMA) { + print(","); + } else { + println(); + if (node->natural()) { + print("NATURAL"); + } + print("LAST"); + print(node->GetSQLForJoinHint()); + + print("JOIN"); + } + println(); + + // This will print hints, the rhs, and the ON or USING clause. + for (int i = 1; i < node->num_children(); i++) { + node->child(i)->Accept(this, data); + if (node->child(i) == node->rhs() && order) { + // optional order by after rhs + print("ORDER BY"); + order->Accept(this, data); + } + } + } +}; + +absl::StatusOr Rewrite(absl::string_view query) { + std::unique_ptr ast; + auto s = hybridse::plan::ParseStatement(query, &ast); + if (!s.ok()) { + return s; + } + + if (ast->statement() && ast->statement()->node_kind() == zetasql::AST_QUERY_STATEMENT) { + std::string unparsed_; + ANSISQLRewriteUnparser unparser(&unparsed_); + ast->statement()->Accept(&unparser, nullptr); + unparser.FlushLine(); + return unparsed_; + } + + return std::string(query); +} + +} // namespace rewriter +} // namespace hybridse diff --git a/hybridse/src/rewriter/ast_rewriter.h b/hybridse/src/rewriter/ast_rewriter.h new file mode 100644 index 00000000000..17ea7ad0d04 --- /dev/null +++ b/hybridse/src/rewriter/ast_rewriter.h @@ -0,0 +1,32 @@ +/** + * Copyright (c) 2024 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HYBRIDSE_SRC_REWRITER_AST_REWRITER_H_ +#define HYBRIDSE_SRC_REWRITER_AST_REWRITER_H_ + +#include + +#include "absl/status/statusor.h" + +namespace hybridse { +namespace rewriter { + +absl::StatusOr Rewrite(absl::string_view query); + +} // namespace rewriter +} // namespace hybridse + +#endif // HYBRIDSE_SRC_REWRITER_AST_REWRITER_H_ diff --git a/hybridse/src/rewriter/ast_rewriter_test.cc b/hybridse/src/rewriter/ast_rewriter_test.cc new file mode 100644 index 00000000000..38e0c7b3115 --- /dev/null +++ b/hybridse/src/rewriter/ast_rewriter_test.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "rewriter/ast_rewriter.h" + +#include +#include + +#include "gtest/gtest.h" +#include "plan/plan_api.h" +#include "zetasql/parser/parser.h" + +namespace hybridse { +namespace rewriter { + +class ASTRewriterTest : public ::testing::Test {}; + +TEST_F(ASTRewriterTest, LastJoin) { + std::string str = R"( +SELECT id, val, k, ts, idr, valr FROM ( + SELECT t1.*, t2.id as idr, t2.val as valr, row_number() over w as row_id + FROM t1 LEFT JOIN t2 ON t1.k = t2.k + WINDOW w as (PARTITION BY t1.id,t1.k order by t2.ts desc) +) t WHERE row_id = 1)"; + + auto s = hybridse::rewriter::Rewrite(str); + ASSERT_TRUE(s.ok()) << s.status(); + + ASSERT_EQ(R"(SELECT + id, + val, + k, + ts, + idr, + valr +FROM t1 +LAST JOIN +t2 ORDER BY t2.ts +ON t1.k = t2.k +)", + s.value()); + + std::unique_ptr out; + auto ss = ::hybridse::plan::ParseStatement(s.value(), &out); + ASSERT_TRUE(ss.ok()) << ss; +} + +} // namespace rewriter +} // namespace hybridse + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From b3f896e3c394bb1cc8109c2aefc97af2bfb1ff34 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 22 May 2024 11:22:00 +0000 Subject: [PATCH 2/5] feat(draft): translate request mode query --- hybridse/src/rewriter/ast_rewriter.cc | 21 +++++++++++++++------ hybridse/src/rewriter/ast_rewriter_test.cc | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/hybridse/src/rewriter/ast_rewriter.cc b/hybridse/src/rewriter/ast_rewriter.cc index f0b77c75bde..27695fb852c 100644 --- a/hybridse/src/rewriter/ast_rewriter.cc +++ b/hybridse/src/rewriter/ast_rewriter.cc @@ -28,12 +28,12 @@ namespace rewriter { // unparser that make some rewrites so outputed SQL is // compatible with ANSI SQL as much as can -class ANSISQLRewriteUnparser : public zetasql::parser::Unparser { +class LastJoinRewriteUnparser : public zetasql::parser::Unparser { public: - explicit ANSISQLRewriteUnparser(std::string* unparsed) : zetasql::parser::Unparser(unparsed) {} - ~ANSISQLRewriteUnparser() override {} - ANSISQLRewriteUnparser(const ANSISQLRewriteUnparser&) = delete; - ANSISQLRewriteUnparser& operator=(const ANSISQLRewriteUnparser&) = delete; + explicit LastJoinRewriteUnparser(std::string* unparsed) : zetasql::parser::Unparser(unparsed) {} + ~LastJoinRewriteUnparser() override {} + LastJoinRewriteUnparser(const LastJoinRewriteUnparser&) = delete; + LastJoinRewriteUnparser& operator=(const LastJoinRewriteUnparser&) = delete; void visitASTSelect(const zetasql::ASTSelect* node, void* data) override { while (true) { @@ -126,6 +126,7 @@ class ANSISQLRewriteUnparser : public zetasql::parser::Unparser { if (ph->num_names() == 1 && absl::AsciiStrToLower(ph->first_name()->GetAsStringView()) == "row_number") { found = true; + break; } } } @@ -257,6 +258,14 @@ class ANSISQLRewriteUnparser : public zetasql::parser::Unparser { } }; +class RequestQueryRewriteUnparser : public zetasql::parser::Unparser { + public: + explicit RequestQueryRewriteUnparser(std::string* unparsed) : zetasql::parser::Unparser(unparsed) {} + ~RequestQueryRewriteUnparser() override {} + RequestQueryRewriteUnparser(const RequestQueryRewriteUnparser&) = delete; + RequestQueryRewriteUnparser& operator=(const RequestQueryRewriteUnparser&) = delete; +}; + absl::StatusOr Rewrite(absl::string_view query) { std::unique_ptr ast; auto s = hybridse::plan::ParseStatement(query, &ast); @@ -266,7 +275,7 @@ absl::StatusOr Rewrite(absl::string_view query) { if (ast->statement() && ast->statement()->node_kind() == zetasql::AST_QUERY_STATEMENT) { std::string unparsed_; - ANSISQLRewriteUnparser unparser(&unparsed_); + LastJoinRewriteUnparser unparser(&unparsed_); ast->statement()->Accept(&unparser, nullptr); unparser.FlushLine(); return unparsed_; diff --git a/hybridse/src/rewriter/ast_rewriter_test.cc b/hybridse/src/rewriter/ast_rewriter_test.cc index 38e0c7b3115..fbd78098118 100644 --- a/hybridse/src/rewriter/ast_rewriter_test.cc +++ b/hybridse/src/rewriter/ast_rewriter_test.cc @@ -31,10 +31,10 @@ class ASTRewriterTest : public ::testing::Test {}; TEST_F(ASTRewriterTest, LastJoin) { std::string str = R"( SELECT id, val, k, ts, idr, valr FROM ( - SELECT t1.*, t2.id as idr, t2.val as valr, row_number() over w as row_id + SELECT t1.*, t2.id as idr, t2.val as valr, row_number() over w as any_id FROM t1 LEFT JOIN t2 ON t1.k = t2.k WINDOW w as (PARTITION BY t1.id,t1.k order by t2.ts desc) -) t WHERE row_id = 1)"; +) t WHERE any_id = 1)"; auto s = hybridse::rewriter::Rewrite(str); ASSERT_TRUE(s.ok()) << s.status(); From c748f58b52c17840efba795546cb43c8dedcd1d8 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Thu, 23 May 2024 14:37:05 +0000 Subject: [PATCH 3/5] feat: request query rewriter --- hybridse/src/rewriter/ast_rewriter.cc | 249 ++++++++++++++++++++- hybridse/src/rewriter/ast_rewriter_test.cc | 57 ++++- 2 files changed, 291 insertions(+), 15 deletions(-) diff --git a/hybridse/src/rewriter/ast_rewriter.cc b/hybridse/src/rewriter/ast_rewriter.cc index 27695fb852c..de36b6d3110 100644 --- a/hybridse/src/rewriter/ast_rewriter.cc +++ b/hybridse/src/rewriter/ast_rewriter.cc @@ -17,6 +17,7 @@ #include "rewriter/ast_rewriter.h" #include +#include #include "plan/plan_api.h" #include "zetasql/parser/parse_tree_manual.h" @@ -84,12 +85,18 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser { } auto select = subquery->query_expr()->GetAsOrNull(); + if (!select) { + break; + } // select have window if (select->window_clause() == nullptr || select->from_clause() == nullptr) { break; } // 3. CHECK FROM CLAUSE: must 't1 LEFT JOIN t2 on t1.key = t2.key' + if (!select->from_clause()) { + break; + } auto join = select->from_clause()->table_expression()->GetAsOrNull(); if (join == nullptr || join->join_type() != zetasql::ASTJoin::LEFT || join->on_clause() == nullptr) { break; @@ -258,30 +265,250 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser { } }; +// SELECT: +// WHERE col = 0 +// FROM (subquery): +// subquery is UNION ALL, or contains left-most query is UNION ALL +// and UNION ALL is select const ..., 0 as col UNION ALL (select .., 1 as col table) class RequestQueryRewriteUnparser : public zetasql::parser::Unparser { public: explicit RequestQueryRewriteUnparser(std::string* unparsed) : zetasql::parser::Unparser(unparsed) {} ~RequestQueryRewriteUnparser() override {} RequestQueryRewriteUnparser(const RequestQueryRewriteUnparser&) = delete; RequestQueryRewriteUnparser& operator=(const RequestQueryRewriteUnparser&) = delete; + + void visitASTSelect(const zetasql::ASTSelect* node, void* data) override { + while (true) { + if (outer_most_select_ != nullptr) { + break; + } + + outer_most_select_ = node; + if (node->where_clause() == nullptr) { + break; + } + absl::string_view filter_col; + const zetasql::ASTExpression* filter_expr; + + // 1. filter condition is 'col = 0' + if (node->where_clause()->expression()->node_kind() != zetasql::AST_BINARY_EXPRESSION) { + break; + } + auto expr = node->where_clause()->expression()->GetAsOrNull(); + if (!expr || expr->op() != zetasql::ASTBinaryExpression::Op::EQ || expr->is_not()) { + break; + } + { + auto rval = expr->rhs()->GetAsOrNull(); + if (rval) { + // TODO(someone): + // 2. consider rval as . + filter_col = rval->last_name()->GetAsStringView(); + filter_expr = expr->lhs(); + } + } + if (filter_col.empty()) { + auto rval = expr->lhs()->GetAsOrNull(); + if (rval) { + // TODO(someone): + // 2. consider rval as . + filter_col = rval->last_name()->GetAsStringView(); + filter_expr = expr->rhs(); + } + } + if (filter_col.empty() || !filter_expr) { + break; + } + + if (node->from_clause() == nullptr) { + break; + } + auto sub = node->from_clause()->table_expression()->GetAsOrNull(); + if (!sub) { + break; + } + auto subquery = sub->subquery(); + + findUnionAllForQuery(subquery, filter_col, filter_expr, node->where_clause()); + + break; // fallback normal + } + + zetasql::parser::Unparser::visitASTSelect(node, data); + } + + void visitASTSetOperation(const zetasql::ASTSetOperation* node, void* data) override { + if (node == detected_request_block_) { + node->inputs().back()->Accept(this, data); + } else { + zetasql::parser::Unparser::visitASTSetOperation(node, data); + } + } + + void visitASTQueryStatement(const zetasql::ASTQueryStatement* node, void* data) override { + visitASTQuery(node->query(), data); + if (!list_.empty()) { + constSelectListAsConfigClause(list_, data); + } + } + + void visitASTWhereClause(const zetasql::ASTWhereClause* node, void* data) override { + if (node != filter_clause_) { + zetasql::parser::Unparser::visitASTWhereClause(node, data); + } + } + + private: + void findUnionAllForQuery(const zetasql::ASTQuery* query, absl::string_view label_name, + const zetasql::ASTExpression* filter_expr, const zetasql::ASTWhereClause* filter) { + if (!query) { + return; + } + auto qe = query->query_expr(); + switch (qe->node_kind()) { + case zetasql::AST_SET_OPERATION: { + auto set = qe->GetAsOrNull(); + if (set && set->op_type() == zetasql::ASTSetOperation::UNION && set->distinct() == false && + set->hint() == nullptr && set->inputs().size() == 2) { + [[maybe_unused]] bool ret = + findUnionAllInput(set->inputs().at(0), set->inputs().at(1), label_name, filter_expr, filter) || + findUnionAllInput(set->inputs().at(0), set->inputs().at(1), label_name, filter_expr, filter); + if (ret) { + detected_request_block_ = set; + } + } + break; + } + case zetasql::AST_QUERY: { + findUnionAllForQuery(qe->GetAsOrNull(), label_name, filter_expr, filter); + break; + } + case zetasql::AST_SELECT: { + auto select = qe->GetAsOrNull(); + if (select->from_clause() && + select->from_clause()->table_expression()->node_kind() == zetasql::AST_TABLE_SUBQUERY) { + auto sub = select->from_clause()->table_expression()->GetAsOrNull(); + if (sub && sub->subquery()) { + findUnionAllForQuery(sub->subquery(), label_name, filter_expr, filter); + } + } + break; + } + default: + break; + } + } + + void constSelectListAsConfigClause(const std::vector& selects, void* data) { + print("CONFIG (execute_mode = 'request', values = ("); + for (int i = 0; i < selects.size(); ++i) { + selects.at(i)->Accept(this, data); + if (i + 1 < selects.size()) { + print(","); + } + } + print(") )"); + } + + bool findUnionAllInput(const zetasql::ASTQueryExpression* lhs, const zetasql::ASTQueryExpression* rhs, + absl::string_view label_name, const zetasql::ASTExpression* filter_expr, + const zetasql::ASTWhereClause* filter) { + // lhs is select const + label_name of value 0 + auto lselect = lhs->GetAsOrNull(); + if (!lselect || lselect->num_children() > 1) { + // only select_list required, otherwise size > 1 + return false; + } + + bool has_label_col_0 = false; + const zetasql::ASTExpression* label_expr_0 = nullptr; + std::vector vec; + for (auto col : lselect->select_list()->columns()) { + if (col->alias() && col->alias()->GetAsStringView() == label_name) { + has_label_col_0 = true; + label_expr_0 = col->expression(); + } else { + vec.push_back(col->expression()); + } + } + + // rhs is simple selects from table + label_name of value 1 + auto rselect = rhs->GetAsOrNull(); + if (!rselect || rselect->num_children() > 2 || !rselect->from_clause()) { + // only select_list + from_clause required + return false; + } + if (rselect->from_clause()->table_expression()->node_kind() != zetasql::AST_TABLE_PATH_EXPRESSION) { + return false; + } + + bool has_label_col_1 = false; + const zetasql::ASTExpression* label_expr_1 = nullptr; + for (auto col : rselect->select_list()->columns()) { + if (col->alias() && col->alias()->GetAsStringView() == label_name) { + has_label_col_1 = true; + label_expr_1 = col->expression(); + } + } + + LOG(INFO) << "label expr 0: " << label_expr_0->SingleNodeDebugString(); + LOG(INFO) << "label expr 1: " << label_expr_1->SingleNodeDebugString(); + LOG(INFO) << "filter expr: " << filter_expr->SingleNodeDebugString(); + + if (has_label_col_0 && has_label_col_1 && + label_expr_0->SingleNodeDebugString() != label_expr_1->SingleNodeDebugString() && + label_expr_0->SingleNodeDebugString() == filter_expr->SingleNodeDebugString()) { + list_ = vec; + filter_clause_ = filter; + return true; + } + + return false; + } + + private: + const zetasql::ASTSelect* outer_most_select_ = nullptr; + // detected request query block, set by when visiting outer most query + const zetasql::ASTSetOperation* detected_request_block_ = nullptr; + const zetasql::ASTWhereClause* filter_clause_; + + std::vector list_; }; absl::StatusOr Rewrite(absl::string_view query) { - std::unique_ptr ast; - auto s = hybridse::plan::ParseStatement(query, &ast); - if (!s.ok()) { - return s; + auto str = std::string(query); + { + std::unique_ptr ast; + auto s = hybridse::plan::ParseStatement(str, &ast); + if (!s.ok()) { + return s; + } + + if (ast->statement() && ast->statement()->node_kind() == zetasql::AST_QUERY_STATEMENT) { + std::string unparsed_; + LastJoinRewriteUnparser unparser(&unparsed_); + ast->statement()->Accept(&unparser, nullptr); + unparser.FlushLine(); + str = unparsed_; + } } + { + std::unique_ptr ast; + auto s = hybridse::plan::ParseStatement(str, &ast); + if (!s.ok()) { + return s; + } - if (ast->statement() && ast->statement()->node_kind() == zetasql::AST_QUERY_STATEMENT) { - std::string unparsed_; - LastJoinRewriteUnparser unparser(&unparsed_); - ast->statement()->Accept(&unparser, nullptr); - unparser.FlushLine(); - return unparsed_; + if (ast->statement() && ast->statement()->node_kind() == zetasql::AST_QUERY_STATEMENT) { + std::string unparsed_; + RequestQueryRewriteUnparser unparser(&unparsed_); + ast->statement()->Accept(&unparser, nullptr); + unparser.FlushLine(); + str = unparsed_; + } } - return std::string(query); + return str; } } // namespace rewriter diff --git a/hybridse/src/rewriter/ast_rewriter_test.cc b/hybridse/src/rewriter/ast_rewriter_test.cc index fbd78098118..0ed554fcf62 100644 --- a/hybridse/src/rewriter/ast_rewriter_test.cc +++ b/hybridse/src/rewriter/ast_rewriter_test.cc @@ -46,10 +46,12 @@ SELECT id, val, k, ts, idr, valr FROM ( ts, idr, valr -FROM t1 -LAST JOIN -t2 ORDER BY t2.ts -ON t1.k = t2.k +FROM + t1 + LAST JOIN + t2 + ORDER BY t2.ts + ON t1.k = t2.k )", s.value()); @@ -58,6 +60,53 @@ ON t1.k = t2.k ASSERT_TRUE(ss.ok()) << ss; } +TEST_F(ASTRewriterTest, RequestQuery) { + std::string str = R"( +SELECT id, k, agg +FROM ( + SELECT id, k, label, count(val) over w as agg + FROM ( + SELECT 6 as id, "xxx" as val, 10 as k, 9000 as ts, 0 as label + UNION ALL + SELECT *, 1 as label FROM t1 + ) t + WINDOW w as (PARTITION BY k ORDER BY ts rows between unbounded preceding and current row) +) t WHERE label = 0)"; + + auto s = hybridse::rewriter::Rewrite(str); + ASSERT_TRUE(s.ok()) << s.status(); + + ASSERT_EQ(R"s(SELECT + id, + k, + agg +FROM + ( + SELECT + id, + k, + label, + count(val) OVER (w) AS agg + FROM + ( + SELECT + *, + 1 AS label + FROM + t1 + ) AS t + WINDOW w AS (PARTITION BY k + ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + ) AS t +CONFIG (execute_mode = 'request', values = (6, "xxx", 10, 9000) ) +)s", + s.value()); + + std::unique_ptr out; + auto ss = ::hybridse::plan::ParseStatement(s.value(), &out); + ASSERT_TRUE(ss.ok()) << ss; +} + } // namespace rewriter } // namespace hybridse From 7b611f81fc078f2fb9f7def0a1b5f0237eb6672c Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Fri, 24 May 2024 09:31:56 +0000 Subject: [PATCH 4/5] test: tpc rewrite cases --- hybridse/src/rewriter/ast_rewriter_test.cc | 176 +++++++++++++++++---- 1 file changed, 144 insertions(+), 32 deletions(-) diff --git a/hybridse/src/rewriter/ast_rewriter_test.cc b/hybridse/src/rewriter/ast_rewriter_test.cc index 0ed554fcf62..4dcd4076dac 100644 --- a/hybridse/src/rewriter/ast_rewriter_test.cc +++ b/hybridse/src/rewriter/ast_rewriter_test.cc @@ -17,8 +17,9 @@ #include "rewriter/ast_rewriter.h" #include -#include +#include +#include "absl/strings/ascii.h" #include "gtest/gtest.h" #include "plan/plan_api.h" #include "zetasql/parser/parser.h" @@ -26,20 +27,22 @@ namespace hybridse { namespace rewriter { -class ASTRewriterTest : public ::testing::Test {}; +struct Case { + absl::string_view in; + absl::string_view out; +}; -TEST_F(ASTRewriterTest, LastJoin) { - std::string str = R"( -SELECT id, val, k, ts, idr, valr FROM ( - SELECT t1.*, t2.id as idr, t2.val as valr, row_number() over w as any_id - FROM t1 LEFT JOIN t2 ON t1.k = t2.k - WINDOW w as (PARTITION BY t1.id,t1.k order by t2.ts desc) -) t WHERE any_id = 1)"; +class ASTRewriterTest : public ::testing::TestWithParam {}; - auto s = hybridse::rewriter::Rewrite(str); - ASSERT_TRUE(s.ok()) << s.status(); - - ASSERT_EQ(R"(SELECT +std::vector strip_cases = { + {R"s( + SELECT id, val, k, ts, idr, valr FROM ( + SELECT t1.*, t2.id as idr, t2.val as valr, row_number() over w as any_id + FROM t1 LEFT JOIN t2 ON t1.k = t2.k + WINDOW w as (PARTITION BY t1.id,t1.k order by t2.ts desc) + ) t WHERE any_id = 1)s", + R"( +SELECT id, val, k, @@ -51,17 +54,8 @@ FROM LAST JOIN t2 ORDER BY t2.ts - ON t1.k = t2.k -)", - s.value()); - - std::unique_ptr out; - auto ss = ::hybridse::plan::ParseStatement(s.value(), &out); - ASSERT_TRUE(ss.ok()) << ss; -} - -TEST_F(ASTRewriterTest, RequestQuery) { - std::string str = R"( + ON t1.k = t2.k)"}, + {R"( SELECT id, k, agg FROM ( SELECT id, k, label, count(val) over w as agg @@ -71,12 +65,9 @@ FROM ( SELECT *, 1 as label FROM t1 ) t WINDOW w as (PARTITION BY k ORDER BY ts rows between unbounded preceding and current row) -) t WHERE label = 0)"; - - auto s = hybridse::rewriter::Rewrite(str); - ASSERT_TRUE(s.ok()) << s.status(); - - ASSERT_EQ(R"s(SELECT +) t WHERE label = 0)", + R"( +SELECT id, k, agg @@ -99,8 +90,129 @@ FROM ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) ) AS t CONFIG (execute_mode = 'request', values = (6, "xxx", 10, 9000) ) -)s", - s.value()); +)"}, + // simplist request query + {R"s( + SELECT id, k + FROM ( + SELECT 6 as id, "xxx" as val, 10 as k, 9000 as ts, 0 as label + UNION ALL + SELECT *, 1 as label FROM t1 + ) t WHERE label = 0)s", + R"s(SELECT + id, + k +FROM + ( + SELECT + *, + 1 AS label + FROM + t1 + ) AS t +CONFIG (execute_mode = 'request', values = (6, "xxx", 10, 9000) ) +)s"}, + + // TPC-C case + {R"(SELECT C_ID, C_CITY, C_STATE, C_CREDIT, C_CREDIT_LIM, C_BALANCE, C_PAYMENT_CNT, C_DELIVERY_CNT + FROM ( + SELECT C_ID, C_CITY, C_STATE, C_CREDIT, C_CREDIT_LIM, C_BALANCE, C_PAYMENT_CNT, C_DELIVERY_CNT, label FROM ( + SELECT 1 AS C_ID, 1 AS C_D_ID, 1 AS C_W_ID, "John" AS C_FIRST, "M" AS C_MIDDLE, "Smith" AS C_LAST, "123 Main St" AS C_STREET_1, "Apt 101" AS C_STREET_2, "Springfield" AS C_CITY, "IL" AS C_STATE, 12345 AS C_ZIP, "555-123-4567" AS C_PHONE, timestamp("2024-01-01 00:00:00") AS C_SINCE, "BC" AS C_CREDIT, 10000.0 AS C_CREDIT_LIM, 0.5 AS C_DISCOUNT, 5000.0 AS C_BALANCE, 0.0 AS C_YTD_PAYMENT, 0 AS C_PAYMENT_CNT, 0 AS C_DELIVERY_CNT, "Additional customer data..." AS C_DATA, 0 as label + UNION ALL + SELECT *, 1 as label FROM CUSTOMER + ) t + ) t WHERE label = 0)", + R"s( +SELECT + C_ID, + C_CITY, + C_STATE, + C_CREDIT, + C_CREDIT_LIM, + C_BALANCE, + C_PAYMENT_CNT, + C_DELIVERY_CNT +FROM + ( + SELECT + C_ID, + C_CITY, + C_STATE, + C_CREDIT, + C_CREDIT_LIM, + C_BALANCE, + C_PAYMENT_CNT, + C_DELIVERY_CNT, + label + FROM + ( + SELECT + *, + 1 AS label + FROM + CUSTOMER + ) AS t + ) AS t +CONFIG (execute_mode = 'request', values = (1, 1, 1, "John", "M", "Smith", "123 Main St", "Apt 101", +"Springfield", "IL", 12345, "555-123-4567", timestamp("2024-01-01 00:00:00"), "BC", 10000.0, 0.5, 5000.0, +0.0, 0, 0, "Additional customer data...") ) + )s"}, + + {R"( +SELECT C_ID, C_CITY, C_STATE, C_CREDIT, C_CREDIT_LIM, C_BALANCE, C_PAYMENT_CNT, C_DELIVERY_CNT + FROM ( + SELECT C_ID, C_CITY, C_STATE, C_CREDIT, C_CREDIT_LIM, C_BALANCE, C_PAYMENT_CNT, C_DELIVERY_CNT, label FROM ( + SELECT 1 AS C_ID, 1 AS C_D_ID, 1 AS C_W_ID, "John" AS C_FIRST, "M" AS C_MIDDLE, "Smith" AS C_LAST, "123 Main St" AS C_STREET_1, "Apt 101" AS C_STREET_2, "Springfield" AS C_CITY, "IL" AS C_STATE, 12345 AS C_ZIP, "555-123-4567" AS C_PHONE, timestamp("2024-01-01 00:00:00") AS C_SINCE, "BC" AS C_CREDIT, 10000.0 AS C_CREDIT_LIM, 0.5 AS C_DISCOUNT, 9000.0 AS C_BALANCE, 0.0 AS C_YTD_PAYMENT, 0 AS C_PAYMENT_CNT, 0 AS C_DELIVERY_CNT, "Additional customer data..." AS C_DATA, 0 as label + UNION ALL + SELECT *, 1 as label FROM CUSTOMER + ) t + ) t WHERE label = 0)", + R"( +SELECT + C_ID, + C_CITY, + C_STATE, + C_CREDIT, + C_CREDIT_LIM, + C_BALANCE, + C_PAYMENT_CNT, + C_DELIVERY_CNT +FROM + ( + SELECT + C_ID, + C_CITY, + C_STATE, + C_CREDIT, + C_CREDIT_LIM, + C_BALANCE, + C_PAYMENT_CNT, + C_DELIVERY_CNT, + label + FROM + ( + SELECT + *, + 1 AS label + FROM + CUSTOMER + ) AS t + ) AS t +CONFIG (execute_mode = 'request', values = (1, 1, 1, "John", "M", "Smith", "123 Main St", "Apt 101", +"Springfield", "IL", 12345, "555-123-4567", timestamp("2024-01-01 00:00:00"), "BC", 10000.0, 0.5, 9000.0, +0.0, 0, 0, "Additional customer data...") ) +)"}, +}; + +INSTANTIATE_TEST_SUITE_P(Rules, ASTRewriterTest, ::testing::ValuesIn(strip_cases)); + +TEST_P(ASTRewriterTest, Correctness) { + auto& c = GetParam(); + + auto s = hybridse::rewriter::Rewrite(c.in); + ASSERT_TRUE(s.ok()) << s.status(); + + ASSERT_EQ(absl::StripAsciiWhitespace(c.out), absl::StripAsciiWhitespace(s.value())); std::unique_ptr out; auto ss = ::hybridse::plan::ParseStatement(s.value(), &out); From 129570814d349e386f3e7f0a30eeb1951ed1fc91 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Sat, 25 May 2024 13:19:15 +0000 Subject: [PATCH 5/5] feat(rewrite): enable ansi sql rewriter in `ExecuteSQL` You may explicitly set this feature on via `set session ansi_sql_rewriter = 'true'` TODO: this rewriter feature should be off by default --- hybridse/src/rewriter/ast_rewriter.cc | 172 ++++++++++++++------- hybridse/src/rewriter/ast_rewriter_test.cc | 21 ++- src/sdk/sql_cluster_router.cc | 39 ++++- src/sdk/sql_cluster_router.h | 2 + 4 files changed, 170 insertions(+), 64 deletions(-) diff --git a/hybridse/src/rewriter/ast_rewriter.cc b/hybridse/src/rewriter/ast_rewriter.cc index de36b6d3110..9dc90ffdee3 100644 --- a/hybridse/src/rewriter/ast_rewriter.cc +++ b/hybridse/src/rewriter/ast_rewriter.cc @@ -19,6 +19,7 @@ #include #include +#include "absl/cleanup/cleanup.h" #include "plan/plan_api.h" #include "zetasql/parser/parse_tree_manual.h" #include "zetasql/parser/parser.h" @@ -84,20 +85,20 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser { break; } - auto select = subquery->query_expr()->GetAsOrNull(); - if (!select) { + auto inner_select = subquery->query_expr()->GetAsOrNull(); + if (!inner_select) { break; } // select have window - if (select->window_clause() == nullptr || select->from_clause() == nullptr) { + if (inner_select->window_clause() == nullptr || inner_select->from_clause() == nullptr) { break; } // 3. CHECK FROM CLAUSE: must 't1 LEFT JOIN t2 on t1.key = t2.key' - if (!select->from_clause()) { + if (!inner_select->from_clause()) { break; } - auto join = select->from_clause()->table_expression()->GetAsOrNull(); + auto join = inner_select->from_clause()->table_expression()->GetAsOrNull(); if (join == nullptr || join->join_type() != zetasql::ASTJoin::LEFT || join->on_clause() == nullptr) { break; } @@ -116,7 +117,7 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser { // 3. CHECK row_id is row_number() over w FROM select_list bool found = false; absl::string_view window_name; - for (auto col : select->select_list()->columns()) { + for (auto col : inner_select->select_list()->columns()) { if (col->alias() && col->alias()->GetAsStringView() == filter_col) { auto agg_func = col->expression()->GetAsOrNull(); if (!agg_func || !agg_func->function()) { @@ -132,6 +133,7 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser { auto ph = agg_func->function()->function(); if (ph->num_names() == 1 && absl::AsciiStrToLower(ph->first_name()->GetAsStringView()) == "row_number") { + opt_out_row_number_col_ = col; found = true; break; } @@ -143,11 +145,11 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser { // 4. CHECK WINDOW CLAUSE { - if (select->window_clause()->windows().size() != 1) { + if (inner_select->window_clause()->windows().size() != 1) { // targeting single window only break; } - auto win = select->window_clause()->windows().front(); + auto win = inner_select->window_clause()->windows().front(); if (win->name()->GetAsStringView() != window_name) { break; } @@ -193,39 +195,48 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser { // rewrite { - PrintOpenParenIfNeeded(node); - println(); - print("SELECT"); - if (node->hint() != nullptr) { - node->hint()->Accept(this, data); - } - if (node->anonymization_options() != nullptr) { - print("WITH ANONYMIZATION OPTIONS"); - node->anonymization_options()->Accept(this, data); - } - if (node->distinct()) { - print("DISTINCT"); - } + opt_out_window_ = inner_select->window_clause(); + opt_out_where_ = node->where_clause(); + opt_join_ = join; + opt_in_last_join_order_by_ = e; + absl::Cleanup clean = [&]() { + opt_out_window_ = nullptr; + opt_out_where_ = nullptr; + opt_out_row_number_col_ = nullptr; + opt_join_ = nullptr; + }; + + // inline zetasql::parser::Unparser::visitASTSelect(node, data); + { + PrintOpenParenIfNeeded(node); + println(); + print("SELECT"); + if (node->hint() != nullptr) { + node->hint()->Accept(this, data); + } + if (node->anonymization_options() != nullptr) { + print("WITH ANONYMIZATION OPTIONS"); + node->anonymization_options()->Accept(this, data); + } + if (node->distinct()) { + print("DISTINCT"); + } - // Visit all children except hint() and anonymization_options, which we - // processed above. We can't just use visitASTChildren(node, data) because - // we need to insert the DISTINCT modifier after the hint and anonymization - // nodes and before everything else. - for (int i = 0; i < node->num_children(); ++i) { - const zetasql::ASTNode* child = node->child(i); - if (child == node->from_clause()) { - // this from subquery will simplified to join - println(); - print("FROM"); - visitASTJoinRewrited(join, e, data); - } else if (child != node->hint() && child != node->anonymization_options() && - child != node->where_clause()) { - child->Accept(this, data); + // Visit all children except hint() and anonymization_options, which we + // processed above. We can't just use visitASTChildren(node, data) because + // we need to insert the DISTINCT modifier after the hint and anonymization + // nodes and before everything else. + for (int i = 0; i < node->num_children(); ++i) { + const zetasql::ASTNode* child = node->child(i); + if (child != node->hint() && child != node->anonymization_options()) { + child->Accept(this, data); + } } + + println(); + PrintCloseParenIfNeeded(node); } - println(); - PrintCloseParenIfNeeded(node); return; } } @@ -236,33 +247,75 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser { zetasql::parser::Unparser::visitASTSelect(node, data); } - void visitASTJoinRewrited(const zetasql::ASTJoin* node, const zetasql::ASTPathExpression* order, void* data) { - node->child(0)->Accept(this, data); + void visitASTJoin(const zetasql::ASTJoin* node, void* data) override { + if (opt_join_ && opt_join_ == node) { + node->child(0)->Accept(this, data); - if (node->join_type() == zetasql::ASTJoin::COMMA) { - print(","); - } else { + if (node->join_type() == zetasql::ASTJoin::COMMA) { + print(","); + } else { + println(); + if (node->natural()) { + print("NATURAL"); + } + print("LAST"); + print(node->GetSQLForJoinHint()); + + print("JOIN"); + } println(); - if (node->natural()) { - print("NATURAL"); + + // This will print hints, the rhs, and the ON or USING clause. + for (int i = 1; i < node->num_children(); i++) { + node->child(i)->Accept(this, data); + if (opt_in_last_join_order_by_ && node->child(i)->IsTableExpression()) { + print("ORDER BY"); + opt_in_last_join_order_by_->Accept(this, data); + } } - print("LAST"); - print(node->GetSQLForJoinHint()); - print("JOIN"); + return; } - println(); - // This will print hints, the rhs, and the ON or USING clause. - for (int i = 1; i < node->num_children(); i++) { - node->child(i)->Accept(this, data); - if (node->child(i) == node->rhs() && order) { - // optional order by after rhs - print("ORDER BY"); - order->Accept(this, data); + zetasql::parser::Unparser::visitASTJoin(node, data); + } + + void visitASTSelectList(const zetasql::ASTSelectList* node, void* data) override { + println(); + { + for (int i = 0; i < node->num_children(); i++) { + if (opt_out_row_number_col_ && node->columns(i) == opt_out_row_number_col_) { + continue; + } + if (i > 0) { + println(","); + } + node->child(i)->Accept(this, data); } } } + + void visitASTWindowClause(const zetasql::ASTWindowClause* node, void* data) override { + if (opt_out_window_ && opt_out_window_ == node) { + return; + } + + zetasql::parser::Unparser::visitASTWindowClause(node, data); + } + + void visitASTWhereClause(const zetasql::ASTWhereClause* node, void* data) override { + if (opt_out_where_ && opt_out_where_ == node) { + return; + } + zetasql::parser::Unparser::visitASTWhereClause(node, data); + } + + private: + const zetasql::ASTWindowClause* opt_out_window_ = nullptr; + const zetasql::ASTWhereClause* opt_out_where_ = nullptr; + const zetasql::ASTSelectColumn* opt_out_row_number_col_ = nullptr; + const zetasql::ASTJoin* opt_join_ = nullptr; + const zetasql::ASTPathExpression* opt_in_last_join_order_by_ = nullptr; }; // SELECT: @@ -346,9 +399,14 @@ class RequestQueryRewriteUnparser : public zetasql::parser::Unparser { } void visitASTQueryStatement(const zetasql::ASTQueryStatement* node, void* data) override { - visitASTQuery(node->query(), data); - if (!list_.empty()) { + node->query()->Accept(this, data); + if (!list_.empty() && !node->config_clause()) { constSelectListAsConfigClause(list_, data); + } else { + if (node->config_clause() != nullptr) { + println(); + node->config_clause()->Accept(this, data); + } } } diff --git a/hybridse/src/rewriter/ast_rewriter_test.cc b/hybridse/src/rewriter/ast_rewriter_test.cc index 4dcd4076dac..7585ada71a6 100644 --- a/hybridse/src/rewriter/ast_rewriter_test.cc +++ b/hybridse/src/rewriter/ast_rewriter_test.cc @@ -35,13 +35,14 @@ struct Case { class ASTRewriterTest : public ::testing::TestWithParam {}; std::vector strip_cases = { + // eliminate LEFT JOIN WINDOW -> LAST JOIN {R"s( SELECT id, val, k, ts, idr, valr FROM ( SELECT t1.*, t2.id as idr, t2.val as valr, row_number() over w as any_id FROM t1 LEFT JOIN t2 ON t1.k = t2.k WINDOW w as (PARTITION BY t1.id,t1.k order by t2.ts desc) ) t WHERE any_id = 1)s", - R"( + R"e( SELECT id, val, @@ -50,11 +51,19 @@ SELECT idr, valr FROM - t1 - LAST JOIN - t2 - ORDER BY t2.ts - ON t1.k = t2.k)"}, + ( + SELECT + t1.*, + t2.id AS idr, + t2.val AS valr + FROM + t1 + LAST JOIN + t2 + ORDER BY t2.ts + ON t1.k = t2.k + ) AS t +)e"}, {R"( SELECT id, k, agg FROM ( diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index e58eb8cd2cc..068538ba5e0 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -50,6 +50,7 @@ #include "plan/plan_api.h" #include "proto/fe_common.pb.h" #include "proto/tablet.pb.h" +#include "rewriter/ast_rewriter.h" #include "rpc/rpc_client.h" #include "schema/schema_adapter.h" #include "sdk/base.h" @@ -2676,14 +2677,34 @@ std::shared_ptr SQLClusterRouter::ExecuteSQL(const std } std::shared_ptr SQLClusterRouter::ExecuteSQL( - const std::string& db, const std::string& sql, std::shared_ptr parameter, + const std::string& db, const std::string& str, std::shared_ptr parameter, bool is_online_mode, bool is_sync_job, int offline_job_timeout, hybridse::sdk::Status* status) { RET_IF_NULL_AND_WARN(status, "output status is nullptr"); // functions we called later may not change the status if it's succeed. So if we pass error status here, we'll get a // fake error status->SetOK(); + + std::string sql = str; hybridse::vm::SqlContext ctx; + if (ANSISQLRewriterEnabled()) { + // If true, enable the ANSI SQL rewriter that would rewrite some SQL query + // for pre-defined pattern to OpenMLDB SQL extensions. Rewrite phase is before general SQL compilation. + // + // OpenMLDB SQL extensions, such as request mode query or LAST JOIN, would be helpful + // to simplify those that comes from like SparkSQL, and reserve the same semantics meaning. + // + // Rewrite rules are based on ASTNode, possibly lack some semantic checks. Turn it off if things + // go abnormal during rewrite phase. + auto s = hybridse::rewriter::Rewrite(sql); + if (s.ok()) { + LOG(INFO) << "rewrited: " << s.value(); + sql = s.value(); + } else { + LOG(WARNING) << s.status(); + } + } ctx.sql = sql; + auto sql_status = hybridse::plan::PlanAPI::CreatePlanTreeFromScript(&ctx); if (!sql_status.isOK()) { COPY_PREPEND_AND_WARN(status, sql_status, "create logic plan tree failed"); @@ -3196,6 +3217,18 @@ bool SQLClusterRouter::IsSyncJob() { return false; } +bool SQLClusterRouter::ANSISQLRewriterEnabled() { + // TODO(xxx): mark fn const + + std::lock_guard<::openmldb::base::SpinMutex> lock(mu_); + auto it = session_variables_.find("ansi_sql_rewriter"); + if (it != session_variables_.end() && it->second == "false") { + return false; + } + // TODO(xxx): always disable by default + return true; +} + int SQLClusterRouter::GetJobTimeout() { std::lock_guard<::openmldb::base::SpinMutex> lock(mu_); auto it = session_variables_.find("job_timeout"); @@ -3267,6 +3300,10 @@ ::hybridse::sdk::Status SQLClusterRouter::SetVariable(hybridse::node::SetPlanNod return {StatusCode::kCmdError, "Fail to parse spark config, set like 'spark.executor.memory=2g;spark.executor.cores=2'"}; } + } else if (key == "ansi_sql_rewriter") { + if (value != "true" && value != "false") { + return {StatusCode::kCmdError, "the value of " + key + " must be true|false"}; + } } else { return {}; } diff --git a/src/sdk/sql_cluster_router.h b/src/sdk/sql_cluster_router.h index 3d13cafa240..e917c170a14 100644 --- a/src/sdk/sql_cluster_router.h +++ b/src/sdk/sql_cluster_router.h @@ -443,6 +443,8 @@ class SQLClusterRouter : public SQLRouter { const base::Slice& value, const std::vector>& tablets); + bool ANSISQLRewriterEnabled(); + private: std::shared_ptr options_; std::string db_;