diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index e08f25d3c27c..277efd5fe700 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -30,8 +30,10 @@ use super::{ BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }, - rewrite::normalize_union_schema, - rewrite::rewrite_plan_for_sort_on_non_projected_fields, + rewrite::{ + normalize_union_schema, rewrite_plan_for_sort_on_non_projected_fields, + subquery_alias_inner_query_and_columns, + }, utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant}, Unparser, }; @@ -687,67 +689,6 @@ impl Unparser<'_> { } } -// This logic is to work out the columns and inner query for SubqueryAlias plan for both types of -// subquery -// - `(SELECT column_a as a from table) AS A` -// - `(SELECT column_a from table) AS A (a)` -// -// A roundtrip example for table alias with columns -// -// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) -// -// LogicPlan: -// Projection: c.id -// SubqueryAlias: c -// Projection: j1.j1_id AS id -// Projection: j1.j1_id -// TableScan: j1 -// -// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS -// id FROM (SELECT j1.j1_id FROM j1)) AS c`. -// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table -// `(SELECT j1.j1_id FROM j1)` -// -// With this logic, the unparsed query will be: -// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` -// -// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` -// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and -// Column in the Projections. 
Once the parser side is fixed, this logic should work -fn subquery_alias_inner_query_and_columns( - subquery_alias: &datafusion_expr::SubqueryAlias, -) -> (&LogicalPlan, Vec) { - let plan: &LogicalPlan = subquery_alias.input.as_ref(); - - let LogicalPlan::Projection(outer_projections) = plan else { - return (plan, vec![]); - }; - - // check if it's projection inside projection - let LogicalPlan::Projection(inner_projection) = outer_projections.input.as_ref() - else { - return (plan, vec![]); - }; - - let mut columns: Vec = vec![]; - // check if the inner projection and outer projection have a matching pattern like - // Projection: j1.j1_id AS id - // Projection: j1.j1_id - for (i, inner_expr) in inner_projection.expr.iter().enumerate() { - let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else { - return (plan, vec![]); - }; - - if outer_alias.expr.as_ref() != inner_expr { - return (plan, vec![]); - }; - - columns.push(outer_alias.name.as_str().into()); - } - - (outer_projections.input.as_ref(), columns) -} - impl From for DataFusionError { fn from(e: BuilderError) -> Self { DataFusionError::External(Box::new(e)) diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index fba95ad48f32..f6725485f920 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -25,6 +25,7 @@ use datafusion_common::{ Result, }; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort}; +use sqlparser::ast::Ident; /// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. 
/// @@ -137,14 +138,25 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( let inner_exprs = inner_p .expr .iter() - .map(|f| { - if let Expr::Alias(alias) = f { + .enumerate() + .map(|(i, f)| match f { + Expr::Alias(alias) => { let a = Expr::Column(alias.name.clone().into()); map.insert(a.clone(), f.clone()); a - } else { + } + Expr::Column(_) => { + map.insert( + Expr::Column(inner_p.schema.field(i).name().into()), + f.clone(), + ); f.clone() } + _ => { + let a = Expr::Column(inner_p.schema.field(i).name().into()); + map.insert(a.clone(), f.clone()); + a + } }) .collect::>(); @@ -155,9 +167,17 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( } } - if collects.iter().collect::>() - == inner_exprs.iter().collect::>() - { + // Compare outer collects Expr::to_string with inner collected transformed values + // alias -> alias column + // column -> remain + // others, extract schema field name + let outer_collects = collects.iter().map(Expr::to_string).collect::>(); + let inner_collects = inner_exprs + .iter() + .map(Expr::to_string) + .collect::>(); + + if outer_collects == inner_collects { let mut sort = sort.clone(); let mut inner_p = inner_p.clone(); @@ -175,3 +195,80 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( None } } + +// This logic is to work out the columns and inner query for SubqueryAlias plan for both types of +// subquery +// - `(SELECT column_a as a from table) AS A` +// - `(SELECT column_a from table) AS A (a)` +// +// A roundtrip example for table alias with columns +// +// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) +// +// LogicalPlan: +// Projection: c.id +// SubqueryAlias: c +// Projection: j1.j1_id AS id +// Projection: j1.j1_id +// TableScan: j1 +// +// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS +// id FROM (SELECT j1.j1_id FROM j1)) AS c`. 
+// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table +// `(SELECT j1.j1_id FROM j1)` +// +// With this logic, the unparsed query will be: +// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` +// +// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` +// as the parser gives a wrong plan which has mismatched `Int(1)` types: Literal and +// Column in the Projections. Once the parser side is fixed, this logic should work +pub(super) fn subquery_alias_inner_query_and_columns( + subquery_alias: &datafusion_expr::SubqueryAlias, +) -> (&LogicalPlan, Vec) { + let plan: &LogicalPlan = subquery_alias.input.as_ref(); + + let LogicalPlan::Projection(outer_projections) = plan else { + return (plan, vec![]); + }; + + // check if there is a projection beneath it, looking through Limit/Distinct/Sort nodes + let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) else { + return (plan, vec![]); + }; + + let mut columns: Vec = vec![]; + // check if the inner projection and outer projection have a matching pattern like + // Projection: j1.j1_id AS id + // Projection: j1.j1_id + for (i, inner_expr) in inner_projection.expr.iter().enumerate() { + let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else { + return (plan, vec![]); + }; + + // inner projection schema fields store the projection name which is used in outer + // projection expr + let inner_expr_string = match inner_expr { + Expr::Column(_) => inner_expr.to_string(), + _ => inner_projection.schema.field(i).name().clone(), + }; + + if outer_alias.expr.to_string() != inner_expr_string { + return (plan, vec![]); + }; + + columns.push(outer_alias.name.as_str().into()); + } + + (outer_projections.input.as_ref(), columns) +} + +fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { + match logical_plan { + LogicalPlan::Projection(p) => Some(p), + LogicalPlan::Limit(p) => find_projection(p.input.as_ref()), + LogicalPlan::Distinct(p) => 
find_projection(p.input().as_ref()), + LogicalPlan::Sort(p) => find_projection(p.input.as_ref()), + _ => None, + } +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 2ac303487336..9bbdbe8dbfc9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -373,6 +373,38 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, + // Test query that has calculation in derived table with columns + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id + 1 * 3 from j1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + (1 * 3)) FROM j1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + // Test query that has limit/distinct/order in derived table with columns + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT distinct (j1_id + 1 * 3) FROM j1 LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT DISTINCT (j1.j1_id + (1 * 3)) FROM j1 LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id + 1 FROM j1 ORDER BY j1_id DESC LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + 1) FROM j1 ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT CAST((CAST(j1_id as BIGINT) + 1) as int) * 10 FROM j1 LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (CAST((CAST(j1.j1_id AS BIGINT) + 1) AS INTEGER) * 10) FROM j1 LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + 
}, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT CAST(j1_id as BIGINT) + 1 FROM j1 ORDER BY j1_id LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + } ]; for query in tests {