From 945902dd5d440bdc360cab60ef31cd0c3bceec41 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sun, 25 Aug 2024 19:23:41 +0800 Subject: [PATCH] fix: preserve expression names when replacing placeholders (#12126) * fix: preserve expression names when replacing placeholders * Add comment --- datafusion/core/src/dataframe/mod.rs | 30 ++++++++++++- datafusion/expr/src/expr_rewriter/mod.rs | 51 ++++++++++++++++++++++ datafusion/expr/src/logical_plan/plan.rs | 23 ++++++---- datafusion/optimizer/src/utils.rs | 55 ++---------------------- datafusion/sql/tests/sql_integration.rs | 15 ++++--- 5 files changed, 107 insertions(+), 67 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 42203e5fe84e..a38e7f45a6f1 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1707,7 +1707,7 @@ mod tests { use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; use arrow::array::{self, Int32Array}; - use datafusion_common::{Constraint, Constraints, ScalarValue}; + use datafusion_common::{assert_batches_eq, Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ @@ -3699,4 +3699,32 @@ mod tests { assert!(result.is_err()); Ok(()) } + + // Test issue: https://github.com/apache/datafusion/issues/12065 + #[tokio::test] + async fn filtered_aggr_with_param_values() -> Result<()> { + let cfg = SessionConfig::new() + .set("datafusion.sql_parser.dialect", "PostgreSQL".into()); + let ctx = SessionContext::new_with_config(cfg); + register_aggregate_csv(&ctx, "table1").await?; + + let df = ctx + .sql("select count (c2) filter (where c3 > $1) from table1") + .await? + .with_param_values(ParamValues::List(vec![ScalarValue::from(10u64)])); + + let df_results = df?.collect().await?; + assert_batches_eq!( + &[ + "+------------------------------------------------+", + "| count(table1.c2) FILTER (WHERE table1.c3 > $1) |", + "+------------------------------------------------+", + "| 54 |", + "+------------------------------------------------+", + ], + &df_results + ); + + Ok(()) + } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index c26970cb053a..768c4aabc840 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -279,6 +279,57 @@ where expr.alias_if_changed(original_name) } +/// Handles ensuring the name of rewritten expressions is not changed. +/// +/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the +/// expression should be preserved: `3 as "1 + 2"` +/// +/// See for details +pub struct NamePreserver { + use_alias: bool, +} + +/// If the name of an expression is remembered, it will be preserved when +/// rewriting the expression +pub struct SavedName(Option); + +impl NamePreserver { + /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan + pub fn new(plan: &LogicalPlan) -> Self { + Self { + use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)), + } + } + + /// Create a new NamePreserver for rewriting the `expr`s in `Projection` + /// + /// This will use aliases + pub fn new_for_projection() -> Self { + Self { use_alias: true } + } + + pub fn save(&self, expr: &Expr) -> Result { + let original_name = if self.use_alias { + Some(expr.name_for_alias()?) + } else { + None + }; + + Ok(SavedName(original_name)) + } +} + +impl SavedName { + /// Ensures the name of the rewritten expression is preserved + pub fn restore(self, expr: Expr) -> Result { + let Self(original_name) = self; + match original_name { + Some(name) => expr.alias_if_changed(name), + None => Ok(expr), + } + } +} + #[cfg(test)] mod test { use std::ops::Add; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ca7d04b9b03e..3ede7f25b753 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -26,7 +26,7 @@ use super::dml::CopyTo; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction}; -use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; +use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols, NamePreserver}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; @@ -1339,15 +1339,20 @@ impl LogicalPlan { ) -> Result { self.transform_up_with_subqueries(|plan| { let schema = Arc::clone(plan.schema()); + let name_preserver = NamePreserver::new(&plan); plan.map_expressions(|e| { - e.infer_placeholder_types(&schema)?.transform_up(|e| { - if let Expr::Placeholder(Placeholder { id, .. }) = e { - let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::Literal(value))) - } else { - Ok(Transformed::no(e)) - } - }) + let original_name = name_preserver.save(&e)?; + let transformed_expr = + e.infer_placeholder_types(&schema)?.transform_up(|e| { + if let Expr::Placeholder(Placeholder { id, .. }) = e { + let value = param_values.get_placeholders_with_values(&id)?; + Ok(Transformed::yes(Expr::Literal(value))) + } else { + Ok(Transformed::no(e)) + } + })?; + // Preserve name to avoid breaking column references to this expression + transformed_expr.map_data(|expr| original_name.restore(expr)) }) }) .map(|res| res.data) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 05b1744d90c5..45cef55bf272 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -28,6 +28,10 @@ use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; use log::{debug, trace}; +/// Re-export of `NamesPreserver` for backwards compatibility, +/// as it was initially placed here and then moved elsewhere. +pub use datafusion_expr::expr_rewriter::NamePreserver; + /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same /// type. Useful for optimizer rules which want to leave the type @@ -294,54 +298,3 @@ pub fn only_or_err(slice: &[T]) -> Result<&T> { pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { expr_utils::merge_schema(inputs) } - -/// Handles ensuring the name of rewritten expressions is not changed. -/// -/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the -/// expression should be preserved: `3 as "1 + 2"` -/// -/// See for details -pub struct NamePreserver { - use_alias: bool, -} - -/// If the name of an expression is remembered, it will be preserved when -/// rewriting the expression -pub struct SavedName(Option); - -impl NamePreserver { - /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan - pub fn new(plan: &LogicalPlan) -> Self { - Self { - use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)), - } - } - - /// Create a new NamePreserver for rewriting the `expr`s in `Projection` - /// - /// This will use aliases - pub fn new_for_projection() -> Self { - Self { use_alias: true } - } - - pub fn save(&self, expr: &Expr) -> Result { - let original_name = if self.use_alias { - Some(expr.name_for_alias()?) - } else { - None - }; - - Ok(SavedName(original_name)) - } -} - -impl SavedName { - /// Ensures the name of the rewritten expression is preserved - pub fn restore(self, expr: Expr) -> Result { - let Self(original_name) = self; - match original_name { - Some(name) => expr.alias_if_changed(name), - None => Ok(expr), - } - } -} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 5685e09c9c9f..5a203703e967 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3813,7 +3813,7 @@ fn test_prepare_statement_to_plan_params_as_constants() { /////////////////// // replace params with values let param_values = vec![ScalarValue::Int32(Some(10))]; - let expected_plan = "Projection: Int32(10)\n EmptyRelation"; + let expected_plan = "Projection: Int32(10) AS $1\n EmptyRelation"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -3829,7 +3829,8 @@ fn test_prepare_statement_to_plan_params_as_constants() { /////////////////// // replace params with values let param_values = vec![ScalarValue::Int32(Some(10))]; - let expected_plan = "Projection: Int64(1) + Int32(10)\n EmptyRelation"; + let expected_plan = + "Projection: Int64(1) + Int32(10) AS Int64(1) + $1\n EmptyRelation"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -3848,7 +3849,9 @@ fn test_prepare_statement_to_plan_params_as_constants() { ScalarValue::Int32(Some(10)), ScalarValue::Float64(Some(10.0)), ]; - let expected_plan = "Projection: Int64(1) + Int32(10) + Float64(10)\n EmptyRelation"; + let expected_plan = + "Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2\ + \n EmptyRelation"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } @@ -4063,7 +4066,7 @@ fn test_prepare_statement_insert_infer() { \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ - \n Values: (UInt32(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; + \n Values: (UInt32(1) AS $1, Utf8(\"Alan\") AS $2, Utf8(\"Turing\") AS $3)"; let plan = plan.replace_params_with_values(¶m_values).unwrap(); prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -4144,7 +4147,7 @@ fn test_prepare_statement_to_plan_multi_params() { ScalarValue::from("xyz"), ]; let expected_plan = - "Projection: person.id, person.age, Utf8(\"xyz\")\ + "Projection: person.id, person.age, Utf8(\"xyz\") AS $6\ \n Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8(\"abc\")\ \n TableScan: person"; @@ -4213,7 +4216,7 @@ fn test_prepare_statement_to_plan_value_list() { let expected_plan = "Projection: *\ \n SubqueryAlias: t\ \n Projection: column1 AS num, column2 AS letter\ - \n Values: (Int64(1), Utf8(\"a\")), (Int64(2), Utf8(\"b\"))"; + \n Values: (Int64(1), Utf8(\"a\") AS $1), (Int64(2), Utf8(\"b\") AS $2)"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); }