Skip to content

Commit

Permalink
feat: improve support for postgres extended protocol (#4721)
Browse files Browse the repository at this point in the history
* feat: improve support for postgres extended protocol

* fix: lint fix

* fix: test code

* fix: adopt upstream

* refactor: remove dup code

* refactor: avoid copy on error message
  • Loading branch information
sunng87 authored Sep 19, 2024
1 parent 52d627e commit 8786624
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 247 deletions.
18 changes: 13 additions & 5 deletions src/query/src/datafusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,19 @@ impl QueryEngine for DatafusionQueryEngine {
query_ctx: QueryContextRef,
) -> Result<DescribeResult> {
let ctx = self.engine_context(query_ctx);
let optimised_plan = self.optimize(&ctx, &plan)?;
Ok(DescribeResult {
schema: optimised_plan.schema()?,
logical_plan: optimised_plan,
})
if let Ok(optimised_plan) = self.optimize(&ctx, &plan) {
Ok(DescribeResult {
schema: optimised_plan.schema()?,
logical_plan: optimised_plan,
})
} else {
// Table's like those in information_schema cannot be optimized when
// it contains parameters. So we fallback to original plans.
Ok(DescribeResult {
schema: plan.schema()?,
logical_plan: plan,
})
}
}

async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
Expand Down
22 changes: 11 additions & 11 deletions src/servers/src/postgres/fixtures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,19 @@ static SET_TRANSACTION_PATTERN: Lazy<Regex> =
static TRANSACTION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(BEGIN|ROLLBACK|COMMIT);?").unwrap());

/// Test if given query statement matches the patterns
pub(crate) fn matches(query: &str) -> bool {
TRANSACTION_PATTERN.captures(query).is_some()
|| SHOW_PATTERN.captures(query).is_some()
|| SET_TRANSACTION_PATTERN.is_match(query)
}

/// Process unsupported SQL and return fixed result as a compatibility solution
pub(crate) fn process<'a>(
query: &str,
_query_ctx: QueryContextRef,
) -> Option<PgWireResult<Vec<Response<'a>>>> {
pub(crate) fn process<'a>(query: &str, _query_ctx: QueryContextRef) -> Option<Vec<Response<'a>>> {
// Transaction directives:
if let Some(tx) = TRANSACTION_PATTERN.captures(query) {
let tx_tag = &tx[1];
Some(Ok(vec![Response::Execution(Tag::new(
&tx_tag.to_uppercase(),
))]))
Some(vec![Response::Execution(Tag::new(&tx_tag.to_uppercase()))])
} else if let Some(show_var) = SHOW_PATTERN.captures(query) {
let show_var = show_var[1].to_lowercase();
if let Some(value) = VAR_VALUES.get(&show_var.as_ref()) {
Expand All @@ -81,12 +83,12 @@ pub(crate) fn process<'a>(
vec![vec![value.to_string()]],
));

Some(Ok(vec![Response::Query(QueryResponse::new(schema, data))]))
Some(vec![Response::Query(QueryResponse::new(schema, data))])
} else {
None
}
} else if SET_TRANSACTION_PATTERN.is_match(query) {
Some(Ok(vec![Response::Execution(Tag::new("SET"))]))
Some(vec![Response::Execution(Tag::new("SET"))])
} else {
None
}
Expand All @@ -101,7 +103,6 @@ mod test {
fn assert_tag(q: &str, t: &str, query_context: QueryContextRef) {
if let Response::Execution(tag) = process(q, query_context.clone())
.unwrap_or_else(|| panic!("fail to match {}", q))
.expect("unexpected error")
.remove(0)
{
assert_eq!(Tag::new(t), tag);
Expand All @@ -113,7 +114,6 @@ mod test {
fn get_data<'a>(q: &str, query_context: QueryContextRef) -> QueryResponse<'a> {
if let Response::Query(resp) = process(q, query_context.clone())
.unwrap_or_else(|| panic!("fail to match {}", q))
.expect("unexpected error")
.remove(0)
{
resp
Expand Down
49 changes: 48 additions & 1 deletion src/servers/src/postgres/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,13 @@ impl SimpleQueryHandler for PostgresServerHandler {
.with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
.start_timer();

if query.is_empty() {
// early return if query is empty
return Ok(vec![Response::EmptyQuery]);
}

if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
resps
Ok(resps)
} else {
let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;

Expand Down Expand Up @@ -184,6 +189,16 @@ impl QueryParser for DefaultQueryParser {
async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
let query_ctx = self.session.new_query_context();

// do not parse if query is empty or matches rules
if sql.is_empty() || fixtures::matches(sql) {
return Ok(SqlPlan {
query: sql.to_owned(),
plan: None,
schema: None,
});
}

let mut stmts =
ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Expand All @@ -193,6 +208,7 @@ impl QueryParser for DefaultQueryParser {
))))
} else {
let stmt = stmts.remove(0);

let describe_result = self
.query_handler
.do_describe(stmt, query_ctx)
Expand Down Expand Up @@ -244,6 +260,16 @@ impl ExtendedQueryHandler for PostgresServerHandler {

let sql_plan = &portal.statement.statement;

if sql_plan.query.is_empty() {
// early return if query is empty
return Ok(Response::EmptyQuery);
}

if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
// if the statement matches our predefined rules, return it early
return Ok(resps.remove(0));
}

let output = if let Some(plan) = &sql_plan.plan {
let plan = plan
.replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref())
Expand Down Expand Up @@ -297,6 +323,17 @@ impl ExtendedQueryHandler for PostgresServerHandler {
.map(|fields| DescribeStatementResponse::new(param_types, fields))
.map_err(|e| PgWireError::ApiError(Box::new(e)))
} else {
if let Some(mut resp) =
fixtures::process(&sql_plan.query, self.session.new_query_context())
{
if let Response::Query(query_response) = resp.remove(0) {
return Ok(DescribeStatementResponse::new(
param_types,
(*query_response.row_schema()).clone(),
));
}
}

Ok(DescribeStatementResponse::new(param_types, vec![]))
}
}
Expand All @@ -317,6 +354,16 @@ impl ExtendedQueryHandler for PostgresServerHandler {
.map(DescribePortalResponse::new)
.map_err(|e| PgWireError::ApiError(Box::new(e)))
} else {
if let Some(mut resp) =
fixtures::process(&sql_plan.query, self.session.new_query_context())
{
if let Response::Query(query_response) = resp.remove(0) {
return Ok(DescribePortalResponse::new(
(*query_response.row_schema()).clone(),
));
}
}

Ok(DescribePortalResponse::new(vec![]))
}
}
Expand Down
Loading

0 comments on commit 8786624

Please sign in to comment.