Skip to content

Commit

Permalink
Improve round scalar function unparsing for Postgres (#12744)
Browse files Browse the repository at this point in the history
* Postgres: enforce required `NUMERIC` type for `round` scalar function (#34)

Includes initial support for dialects to override scalar functions unparsing

* Document scalar_function_to_sql_overrides fn
  • Loading branch information
sgrebnov authored Oct 6, 2024
1 parent ecb0044 commit 9b492c6
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 126 deletions.
119 changes: 118 additions & 1 deletion datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@
use std::sync::Arc;

use arrow_schema::TimeUnit;
use datafusion_expr::Expr;
use regex::Regex;
use sqlparser::{
ast::{self, Ident, ObjectName, TimezoneInfo},
ast::{self, Function, Ident, ObjectName, TimezoneInfo},
keywords::ALL_KEYWORDS,
};

use datafusion_common::Result;

use super::{utils::date_part_to_sql, Unparser};

/// `Dialect` to use for Unparsing
///
/// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`)
Expand Down Expand Up @@ -108,6 +113,18 @@ pub trait Dialect: Send + Sync {
fn supports_column_alias_in_table_alias(&self) -> bool {
true
}

/// Allows the dialect to override scalar function unparsing if the dialect has specific rules.
/// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is
/// a custom implementation for the function.
fn scalar_function_to_sql_overrides(
&self,
_unparser: &Unparser,
_func_name: &str,
_args: &[Expr],
) -> Result<Option<ast::Expr>> {
Ok(None)
}
}

/// `IntervalStyle` to use for unparsing
Expand Down Expand Up @@ -171,6 +188,67 @@ impl Dialect for PostgreSqlDialect {
fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
sqlparser::ast::DataType::DoublePrecision
}

fn scalar_function_to_sql_overrides(
&self,
unparser: &Unparser,
func_name: &str,
args: &[Expr],
) -> Result<Option<ast::Expr>> {
if func_name == "round" {
return Ok(Some(
self.round_to_sql_enforce_numeric(unparser, func_name, args)?,
));
}

Ok(None)
}
}

impl PostgreSqlDialect {
fn round_to_sql_enforce_numeric(
&self,
unparser: &Unparser,
func_name: &str,
args: &[Expr],
) -> Result<ast::Expr> {
let mut args = unparser.function_args_to_sql(args)?;

// Enforce the first argument to be Numeric
if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) =
args.first_mut()
{
if let ast::Expr::Cast { data_type, .. } = expr {
// Don't create an additional cast wrapper if we can update the existing one
*data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None);
} else {
// Wrap the expression in a new cast
*expr = ast::Expr::Cast {
kind: ast::CastKind::Cast,
expr: Box::new(expr.clone()),
data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None),
format: None,
};
}
}

Ok(ast::Expr::Function(Function {
name: ast::ObjectName(vec![Ident {
value: func_name.to_string(),
quote_style: None,
}]),
args: ast::FunctionArguments::List(ast::FunctionArgumentList {
duplicate_treatment: None,
args,
clauses: vec![],
}),
filter: None,
null_treatment: None,
over: None,
within_group: vec![],
parameters: ast::FunctionArguments::None,
}))
}
}

pub struct MySqlDialect {}
Expand Down Expand Up @@ -211,6 +289,19 @@ impl Dialect for MySqlDialect {
) -> ast::DataType {
ast::DataType::Datetime(None)
}

fn scalar_function_to_sql_overrides(
&self,
unparser: &Unparser,
func_name: &str,
args: &[Expr],
) -> Result<Option<ast::Expr>> {
if func_name == "date_part" {
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
}

Ok(None)
}
}

pub struct SqliteDialect {}
Expand All @@ -231,6 +322,19 @@ impl Dialect for SqliteDialect {
fn supports_column_alias_in_table_alias(&self) -> bool {
false
}

fn scalar_function_to_sql_overrides(
&self,
unparser: &Unparser,
func_name: &str,
args: &[Expr],
) -> Result<Option<ast::Expr>> {
if func_name == "date_part" {
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
}

Ok(None)
}
}

pub struct CustomDialect {
Expand Down Expand Up @@ -339,6 +443,19 @@ impl Dialect for CustomDialect {
fn supports_column_alias_in_table_alias(&self) -> bool {
self.supports_column_alias_in_table_alias
}

fn scalar_function_to_sql_overrides(
&self,
unparser: &Unparser,
func_name: &str,
args: &[Expr],
) -> Result<Option<ast::Expr>> {
if func_name == "date_part" {
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
}

Ok(None)
}
}

/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
Expand Down
Loading

0 comments on commit 9b492c6

Please sign in to comment.