Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve unparser MySQL compatibility #11589

Merged
merged 6 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 146 additions & 9 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow_schema::TimeUnit;
use regex::Regex;
use sqlparser::{ast, keywords::ALL_KEYWORDS};
use sqlparser::{
ast::{self, Ident, ObjectName, TimezoneInfo},
keywords::ALL_KEYWORDS,
};

/// `Dialect` to use for Unparsing
///
Expand All @@ -36,8 +42,8 @@ pub trait Dialect: Send + Sync {
true
}

// Does the dialect use TIMESTAMP to represent Date64 rather than DATETIME?
// E.g. Trino, Athena and Dremio does not have DATETIME data type
/// Does the dialect use TIMESTAMP to represent Date64 rather than DATETIME?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

/// E.g. Trino, Athena and Dremio does not have DATETIME data type
fn use_timestamp_for_date64(&self) -> bool {
false
}
Expand All @@ -46,23 +52,50 @@ pub trait Dialect: Send + Sync {
IntervalStyle::PostgresVerbose
}

// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE?
// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE
/// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE?
/// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE
fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
sqlparser::ast::DataType::Double
}

// The SQL type to use for Arrow Utf8 unparsing
// Most dialects use VARCHAR, but some, like MySQL, require CHAR
/// The SQL type to use for Arrow Utf8 unparsing
/// Most dialects use VARCHAR, but some, like MySQL, require CHAR
fn utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Varchar(None)
}

// The SQL type to use for Arrow LargeUtf8 unparsing
// Most dialects use TEXT, but some, like MySQL, require CHAR
/// The SQL type to use for Arrow LargeUtf8 unparsing
/// Most dialects use TEXT, but some, like MySQL, require CHAR
fn large_utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Text
}

/// The date field extract style to use: `DateFieldExtractStyle`
fn date_field_extract_style(&self) -> DateFieldExtractStyle {
DateFieldExtractStyle::DatePart
}

/// The SQL type to use for Arrow Int64 unparsing
/// Most dialects use BigInt, but some, like MySQL, require SIGNED
fn int64_cast_dtype(&self) -> ast::DataType {
ast::DataType::BigInt(None)
}

/// The SQL type to use for Timestamp unparsing
/// Most dialects use Timestamp, but some, like MySQL, require Datetime
/// Some dialects like Dremio does not support WithTimeZone and requires always Timestamp
fn timestamp_cast_dtype(
&self,
_time_unit: &TimeUnit,
tz: &Option<Arc<str>>,
) -> ast::DataType {
let tz_info = match tz {
Some(_) => TimezoneInfo::WithTimeZone,
None => TimezoneInfo::None,
};

ast::DataType::Timestamp(None, tz_info)
}
}

/// `IntervalStyle` to use for unparsing
Expand All @@ -80,6 +113,19 @@ pub enum IntervalStyle {
MySQL,
}

/// Datetime subfield extraction style for unparsing
///
/// `<https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT>`
/// Different DBMSs follow different standards; popular ones are:
/// date_part('YEAR', date '2001-02-16')
/// EXTRACT(YEAR from date '2001-02-16')
/// Some DBMSs, like Postgres, support both, whereas others like MySQL require EXTRACT.
#[derive(Clone, Copy, PartialEq)]
pub enum DateFieldExtractStyle {
DatePart,
Extract,
}

pub struct DefaultDialect {}

impl Dialect for DefaultDialect {
Expand Down Expand Up @@ -133,6 +179,22 @@ impl Dialect for MySqlDialect {
fn large_utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Char(None)
}

fn date_field_extract_style(&self) -> DateFieldExtractStyle {
DateFieldExtractStyle::Extract
}

fn int64_cast_dtype(&self) -> ast::DataType {
ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![])
}

fn timestamp_cast_dtype(
&self,
_time_unit: &TimeUnit,
_tz: &Option<Arc<str>>,
) -> ast::DataType {
ast::DataType::Datetime(None)
}
}

pub struct SqliteDialect {}
Expand All @@ -151,6 +213,10 @@ pub struct CustomDialect {
float64_ast_dtype: sqlparser::ast::DataType,
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
date_field_extract_style: DateFieldExtractStyle,
int64_cast_dtype: ast::DataType,
timestamp_cast_dtype: ast::DataType,
timestamp_tz_cast_dtype: ast::DataType,
}

impl Default for CustomDialect {
Expand All @@ -163,6 +229,13 @@ impl Default for CustomDialect {
float64_ast_dtype: sqlparser::ast::DataType::Double,
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
date_field_extract_style: DateFieldExtractStyle::DatePart,
int64_cast_dtype: ast::DataType::BigInt(None),
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
timestamp_tz_cast_dtype: ast::DataType::Timestamp(
None,
TimezoneInfo::WithTimeZone,
),
}
}
}
Expand Down Expand Up @@ -206,6 +279,26 @@ impl Dialect for CustomDialect {
fn large_utf8_cast_dtype(&self) -> ast::DataType {
self.large_utf8_cast_dtype.clone()
}

fn date_field_extract_style(&self) -> DateFieldExtractStyle {
self.date_field_extract_style
}

fn int64_cast_dtype(&self) -> ast::DataType {
self.int64_cast_dtype.clone()
}

fn timestamp_cast_dtype(
&self,
_time_unit: &TimeUnit,
tz: &Option<Arc<str>>,
) -> ast::DataType {
if tz.is_some() {
self.timestamp_tz_cast_dtype.clone()
} else {
self.timestamp_cast_dtype.clone()
}
}
}

/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
Expand All @@ -230,6 +323,10 @@ pub struct CustomDialectBuilder {
float64_ast_dtype: sqlparser::ast::DataType,
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
date_field_extract_style: DateFieldExtractStyle,
int64_cast_dtype: ast::DataType,
timestamp_cast_dtype: ast::DataType,
timestamp_tz_cast_dtype: ast::DataType,
}

impl Default for CustomDialectBuilder {
Expand All @@ -248,6 +345,13 @@ impl CustomDialectBuilder {
float64_ast_dtype: sqlparser::ast::DataType::Double,
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
date_field_extract_style: DateFieldExtractStyle::DatePart,
int64_cast_dtype: ast::DataType::BigInt(None),
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
timestamp_tz_cast_dtype: ast::DataType::Timestamp(
None,
TimezoneInfo::WithTimeZone,
),
}
}

Expand All @@ -260,6 +364,10 @@ impl CustomDialectBuilder {
float64_ast_dtype: self.float64_ast_dtype,
utf8_cast_dtype: self.utf8_cast_dtype,
large_utf8_cast_dtype: self.large_utf8_cast_dtype,
date_field_extract_style: self.date_field_extract_style,
int64_cast_dtype: self.int64_cast_dtype,
timestamp_cast_dtype: self.timestamp_cast_dtype,
timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype,
}
}

Expand Down Expand Up @@ -293,6 +401,7 @@ impl CustomDialectBuilder {
self
}

/// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc.
pub fn with_float64_ast_dtype(
mut self,
float64_ast_dtype: sqlparser::ast::DataType,
Expand All @@ -301,16 +410,44 @@ impl CustomDialectBuilder {
self
}

/// Customize the dialect with a specific SQL type for Utf8 casting: VARCHAR, CHAR, etc.
pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self {
self.utf8_cast_dtype = utf8_cast_dtype;
self
}

/// Customize the dialect with a specific SQL type for LargeUtf8 casting: TEXT, CHAR, etc.
pub fn with_large_utf8_cast_dtype(
mut self,
large_utf8_cast_dtype: ast::DataType,
) -> Self {
self.large_utf8_cast_dtype = large_utf8_cast_dtype;
self
}

/// Customize the dialect with a specific date field extract style listed in `DateFieldExtractStyle`
pub fn with_date_field_extract_style(
mut self,
date_field_extract_style: DateFieldExtractStyle,
) -> Self {
self.date_field_extract_style = date_field_extract_style;
self
}

/// Customize the dialect with a specific SQL type for Int64 casting: BigInt, SIGNED, etc.
pub fn with_int64_cast_dtype(mut self, int64_cast_dtype: ast::DataType) -> Self {
self.int64_cast_dtype = int64_cast_dtype;
self
}

/// Customize the dialect with a specific SQL type for Timestamp casting: Timestamp, Datetime, etc.
pub fn with_timestamp_cast_dtype(
mut self,
timestamp_cast_dtype: ast::DataType,
timestamp_tz_cast_dtype: ast::DataType,
) -> Self {
self.timestamp_cast_dtype = timestamp_cast_dtype;
self.timestamp_tz_cast_dtype = timestamp_tz_cast_dtype;
self
}
}
Loading