Skip to content

Commit

Permalink
feat(rewrite): enable ansi sql rewriter in ExecuteSQL
Browse files Browse the repository at this point in the history
You may explicitly set this feature on via `set session ansi_sql_rewriter
= 'true'`

TODO: this rewriter feature should be off by default
  • Loading branch information
aceforeverd committed May 25, 2024
1 parent 7b611f8 commit 1295708
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 64 deletions.
172 changes: 115 additions & 57 deletions hybridse/src/rewriter/ast_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "plan/plan_api.h"
#include "zetasql/parser/parse_tree_manual.h"
#include "zetasql/parser/parser.h"
Expand Down Expand Up @@ -84,20 +85,20 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser {
break;

Check warning on line 85 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L85

Added line #L85 was not covered by tests
}

auto select = subquery->query_expr()->GetAsOrNull<zetasql::ASTSelect>();
if (!select) {
auto inner_select = subquery->query_expr()->GetAsOrNull<zetasql::ASTSelect>();
if (!inner_select) {
break;
}
// select have window
if (select->window_clause() == nullptr || select->from_clause() == nullptr) {
if (inner_select->window_clause() == nullptr || inner_select->from_clause() == nullptr) {
break;
}

// 3. CHECK FROM CLAUSE: must 't1 LEFT JOIN t2 on t1.key = t2.key'
if (!select->from_clause()) {
if (!inner_select->from_clause()) {
break;

Check warning on line 99 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L99

Added line #L99 was not covered by tests
}
auto join = select->from_clause()->table_expression()->GetAsOrNull<zetasql::ASTJoin>();
auto join = inner_select->from_clause()->table_expression()->GetAsOrNull<zetasql::ASTJoin>();
if (join == nullptr || join->join_type() != zetasql::ASTJoin::LEFT || join->on_clause() == nullptr) {
break;
}
Expand All @@ -116,7 +117,7 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser {
// 3. CHECK row_id is row_number() over w FROM select_list
bool found = false;
absl::string_view window_name;
for (auto col : select->select_list()->columns()) {
for (auto col : inner_select->select_list()->columns()) {
if (col->alias() && col->alias()->GetAsStringView() == filter_col) {
auto agg_func = col->expression()->GetAsOrNull<zetasql::ASTAnalyticFunctionCall>();
if (!agg_func || !agg_func->function()) {
Expand All @@ -132,6 +133,7 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser {
auto ph = agg_func->function()->function();
if (ph->num_names() == 1 &&
absl::AsciiStrToLower(ph->first_name()->GetAsStringView()) == "row_number") {
opt_out_row_number_col_ = col;
found = true;
break;
}
Expand All @@ -143,11 +145,11 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser {

// 4. CHECK WINDOW CLAUSE
{
if (select->window_clause()->windows().size() != 1) {
if (inner_select->window_clause()->windows().size() != 1) {
// targeting single window only
break;

Check warning on line 150 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L150

Added line #L150 was not covered by tests
}
auto win = select->window_clause()->windows().front();
auto win = inner_select->window_clause()->windows().front();
if (win->name()->GetAsStringView() != window_name) {
break;

Check warning on line 154 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L154

Added line #L154 was not covered by tests
}
Expand Down Expand Up @@ -193,39 +195,48 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser {

// rewrite
{
PrintOpenParenIfNeeded(node);
println();
print("SELECT");
if (node->hint() != nullptr) {
node->hint()->Accept(this, data);
}
if (node->anonymization_options() != nullptr) {
print("WITH ANONYMIZATION OPTIONS");
node->anonymization_options()->Accept(this, data);
}
if (node->distinct()) {
print("DISTINCT");
}
opt_out_window_ = inner_select->window_clause();
opt_out_where_ = node->where_clause();
opt_join_ = join;
opt_in_last_join_order_by_ = e;
absl::Cleanup clean = [&]() {
opt_out_window_ = nullptr;
opt_out_where_ = nullptr;
opt_out_row_number_col_ = nullptr;
opt_join_ = nullptr;
};

// inline zetasql::parser::Unparser::visitASTSelect(node, data);
{
PrintOpenParenIfNeeded(node);
println();
print("SELECT");
if (node->hint() != nullptr) {
node->hint()->Accept(this, data);

Check warning on line 215 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L215

Added line #L215 was not covered by tests
}
if (node->anonymization_options() != nullptr) {
print("WITH ANONYMIZATION OPTIONS");
node->anonymization_options()->Accept(this, data);

Check warning on line 219 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L218-L219

Added lines #L218 - L219 were not covered by tests
}
if (node->distinct()) {
print("DISTINCT");

Check warning on line 222 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L222

Added line #L222 was not covered by tests
}

// Visit all children except hint() and anonymization_options, which we
// processed above. We can't just use visitASTChildren(node, data) because
// we need to insert the DISTINCT modifier after the hint and anonymization
// nodes and before everything else.
for (int i = 0; i < node->num_children(); ++i) {
const zetasql::ASTNode* child = node->child(i);
if (child == node->from_clause()) {
// this from subquery will simplified to join
println();
print("FROM");
visitASTJoinRewrited(join, e, data);
} else if (child != node->hint() && child != node->anonymization_options() &&
child != node->where_clause()) {
child->Accept(this, data);
// Visit all children except hint() and anonymization_options, which we
// processed above. We can't just use visitASTChildren(node, data) because
// we need to insert the DISTINCT modifier after the hint and anonymization
// nodes and before everything else.
for (int i = 0; i < node->num_children(); ++i) {
const zetasql::ASTNode* child = node->child(i);
if (child != node->hint() && child != node->anonymization_options()) {
child->Accept(this, data);
}
}

println();
PrintCloseParenIfNeeded(node);
}

println();
PrintCloseParenIfNeeded(node);
return;
}
}
Expand All @@ -236,33 +247,75 @@ class LastJoinRewriteUnparser : public zetasql::parser::Unparser {
zetasql::parser::Unparser::visitASTSelect(node, data);
}

void visitASTJoinRewrited(const zetasql::ASTJoin* node, const zetasql::ASTPathExpression* order, void* data) {
node->child(0)->Accept(this, data);
void visitASTJoin(const zetasql::ASTJoin* node, void* data) override {
if (opt_join_ && opt_join_ == node) {
node->child(0)->Accept(this, data);

if (node->join_type() == zetasql::ASTJoin::COMMA) {
print(",");
} else {
if (node->join_type() == zetasql::ASTJoin::COMMA) {
print(",");

Check warning on line 255 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L255

Added line #L255 was not covered by tests
} else {
println();
if (node->natural()) {
print("NATURAL");

Check warning on line 259 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L259

Added line #L259 was not covered by tests
}
print("LAST");
print(node->GetSQLForJoinHint());

print("JOIN");
}
println();
if (node->natural()) {
print("NATURAL");

// This will print hints, the rhs, and the ON or USING clause.
for (int i = 1; i < node->num_children(); i++) {
node->child(i)->Accept(this, data);
if (opt_in_last_join_order_by_ && node->child(i)->IsTableExpression()) {
print("ORDER BY");
opt_in_last_join_order_by_->Accept(this, data);
}
}
print("LAST");
print(node->GetSQLForJoinHint());

print("JOIN");
return;
}
println();

// This will print hints, the rhs, and the ON or USING clause.
for (int i = 1; i < node->num_children(); i++) {
node->child(i)->Accept(this, data);
if (node->child(i) == node->rhs() && order) {
// optional order by after rhs
print("ORDER BY");
order->Accept(this, data);
zetasql::parser::Unparser::visitASTJoin(node, data);

Check warning on line 280 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L280

Added line #L280 was not covered by tests
}

void visitASTSelectList(const zetasql::ASTSelectList* node, void* data) override {
println();
{
for (int i = 0; i < node->num_children(); i++) {
if (opt_out_row_number_col_ && node->columns(i) == opt_out_row_number_col_) {
continue;
}
if (i > 0) {
println(",");
}
node->child(i)->Accept(this, data);
}
}
}

void visitASTWindowClause(const zetasql::ASTWindowClause* node, void* data) override {
if (opt_out_window_ && opt_out_window_ == node) {
return;
}

zetasql::parser::Unparser::visitASTWindowClause(node, data);
}

void visitASTWhereClause(const zetasql::ASTWhereClause* node, void* data) override {
if (opt_out_where_ && opt_out_where_ == node) {
return;
}
zetasql::parser::Unparser::visitASTWhereClause(node, data);
}

private:
const zetasql::ASTWindowClause* opt_out_window_ = nullptr;
const zetasql::ASTWhereClause* opt_out_where_ = nullptr;
const zetasql::ASTSelectColumn* opt_out_row_number_col_ = nullptr;
const zetasql::ASTJoin* opt_join_ = nullptr;
const zetasql::ASTPathExpression* opt_in_last_join_order_by_ = nullptr;
};

// SELECT:
Expand Down Expand Up @@ -346,9 +399,14 @@ class RequestQueryRewriteUnparser : public zetasql::parser::Unparser {
}

void visitASTQueryStatement(const zetasql::ASTQueryStatement* node, void* data) override {
visitASTQuery(node->query(), data);
if (!list_.empty()) {
node->query()->Accept(this, data);
if (!list_.empty() && !node->config_clause()) {
constSelectListAsConfigClause(list_, data);
} else {
if (node->config_clause() != nullptr) {
println();
node->config_clause()->Accept(this, data);

Check warning on line 408 in hybridse/src/rewriter/ast_rewriter.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/rewriter/ast_rewriter.cc#L407-L408

Added lines #L407 - L408 were not covered by tests
}
}
}

Expand Down
21 changes: 15 additions & 6 deletions hybridse/src/rewriter/ast_rewriter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ struct Case {
class ASTRewriterTest : public ::testing::TestWithParam<Case> {};

std::vector<Case> strip_cases = {
// eliminate LEFT JOIN WINDOW -> LAST JOIN
{R"s(
SELECT id, val, k, ts, idr, valr FROM (
SELECT t1.*, t2.id as idr, t2.val as valr, row_number() over w as any_id
FROM t1 LEFT JOIN t2 ON t1.k = t2.k
WINDOW w as (PARTITION BY t1.id,t1.k order by t2.ts desc)
) t WHERE any_id = 1)s",
R"(
R"e(
SELECT
id,
val,
Expand All @@ -50,11 +51,19 @@ SELECT
idr,
valr
FROM
t1
LAST JOIN
t2
ORDER BY t2.ts
ON t1.k = t2.k)"},
(
SELECT
t1.*,
t2.id AS idr,
t2.val AS valr
FROM
t1
LAST JOIN
t2
ORDER BY t2.ts
ON t1.k = t2.k
) AS t
)e"},
{R"(
SELECT id, k, agg
FROM (
Expand Down
39 changes: 38 additions & 1 deletion src/sdk/sql_cluster_router.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "plan/plan_api.h"
#include "proto/fe_common.pb.h"
#include "proto/tablet.pb.h"
#include "rewriter/ast_rewriter.h"
#include "rpc/rpc_client.h"
#include "schema/schema_adapter.h"
#include "sdk/base.h"
Expand Down Expand Up @@ -2676,14 +2677,34 @@ std::shared_ptr<hybridse::sdk::ResultSet> SQLClusterRouter::ExecuteSQL(const std
}

std::shared_ptr<hybridse::sdk::ResultSet> SQLClusterRouter::ExecuteSQL(
const std::string& db, const std::string& sql, std::shared_ptr<openmldb::sdk::SQLRequestRow> parameter,
const std::string& db, const std::string& str, std::shared_ptr<openmldb::sdk::SQLRequestRow> parameter,
bool is_online_mode, bool is_sync_job, int offline_job_timeout, hybridse::sdk::Status* status) {
RET_IF_NULL_AND_WARN(status, "output status is nullptr");
// functions we called later may not change the status if it's succeed. So if we pass error status here, we'll get a
// fake error
status->SetOK();

std::string sql = str;
hybridse::vm::SqlContext ctx;
if (ANSISQLRewriterEnabled()) {
// If true, enable the ANSI SQL rewriter that would rewrite some SQL query
// for pre-defined pattern to OpenMLDB SQL extensions. Rewrite phase is before general SQL compilation.
//
// OpenMLDB SQL extensions, such as request mode query or LAST JOIN, would be helpful
// to simplify those that comes from like SparkSQL, and reserve the same semantics meaning.
//
// Rewrite rules are based on ASTNode, possibly lack some semantic checks. Turn it off if things
// go abnormal during rewrite phase.
auto s = hybridse::rewriter::Rewrite(sql);
if (s.ok()) {
LOG(INFO) << "rewrited: " << s.value();
sql = s.value();
} else {
LOG(WARNING) << s.status();

Check warning on line 2703 in src/sdk/sql_cluster_router.cc

View check run for this annotation

Codecov / codecov/patch

src/sdk/sql_cluster_router.cc#L2703

Added line #L2703 was not covered by tests
}
}
ctx.sql = sql;

auto sql_status = hybridse::plan::PlanAPI::CreatePlanTreeFromScript(&ctx);
if (!sql_status.isOK()) {
COPY_PREPEND_AND_WARN(status, sql_status, "create logic plan tree failed");
Expand Down Expand Up @@ -3196,6 +3217,18 @@ bool SQLClusterRouter::IsSyncJob() {
return false;
}

bool SQLClusterRouter::ANSISQLRewriterEnabled() {
// TODO(xxx): mark fn const

std::lock_guard<::openmldb::base::SpinMutex> lock(mu_);
auto it = session_variables_.find("ansi_sql_rewriter");
if (it != session_variables_.end() && it->second == "false") {
return false;

Check warning on line 3226 in src/sdk/sql_cluster_router.cc

View check run for this annotation

Codecov / codecov/patch

src/sdk/sql_cluster_router.cc#L3226

Added line #L3226 was not covered by tests
}
// TODO(xxx): always disable by default
return true;
}

int SQLClusterRouter::GetJobTimeout() {
std::lock_guard<::openmldb::base::SpinMutex> lock(mu_);
auto it = session_variables_.find("job_timeout");
Expand Down Expand Up @@ -3267,6 +3300,10 @@ ::hybridse::sdk::Status SQLClusterRouter::SetVariable(hybridse::node::SetPlanNod
return {StatusCode::kCmdError,
"Fail to parse spark config, set like 'spark.executor.memory=2g;spark.executor.cores=2'"};
}
} else if (key == "ansi_sql_rewriter") {
if (value != "true" && value != "false") {
return {StatusCode::kCmdError, "the value of " + key + " must be true|false"};

Check warning on line 3305 in src/sdk/sql_cluster_router.cc

View check run for this annotation

Codecov / codecov/patch

src/sdk/sql_cluster_router.cc#L3304-L3305

Added lines #L3304 - L3305 were not covered by tests
}
} else {
return {};
}
Expand Down
2 changes: 2 additions & 0 deletions src/sdk/sql_cluster_router.h
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ class SQLClusterRouter : public SQLRouter {
const base::Slice& value,
const std::vector<std::shared_ptr<::openmldb::catalog::TabletAccessor>>& tablets);

bool ANSISQLRewriterEnabled();

private:
std::shared_ptr<BasicRouterOptions> options_;
std::string db_;
Expand Down

0 comments on commit 1295708

Please sign in to comment.