Skip to content

Commit

Permalink
fix NamedStructField should be rewritten in OperatorToFunction in s…
Browse files Browse the repository at this point in the history
…ubquery regression (change `ApplyFunctionRewrites` to use TreeNode API (#10032)

* fix NamedStructField should be rewritten in OperatorToFunction in subquery

* Use TreeNode rewriter
  • Loading branch information
alamb authored Apr 12, 2024
1 parent 2def10f commit e161cd6
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 65 deletions.
99 changes: 34 additions & 65 deletions datafusion/optimizer/src/analyzer/function_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

use super::AnalyzerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite};

use crate::utils::NamePreserver;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::LogicalPlan;
use std::sync::Arc;

/// Analyzer rule that invokes [`FunctionRewrite`]s on expressions
Expand All @@ -37,86 +39,53 @@ impl ApplyFunctionRewrites {
pub fn new(function_rewrites: Vec<Arc<dyn FunctionRewrite + Send + Sync>>) -> Self {
Self { function_rewrites }
}
}

impl AnalyzerRule for ApplyFunctionRewrites {
fn name(&self) -> &str {
"apply_function_rewrites"
}

fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
self.analyze_internal(&plan, options)
}
}

impl ApplyFunctionRewrites {
fn analyze_internal(
/// Rewrite a single plan, and all its expressions using the provided rewriters
fn rewrite_plan(
&self,
plan: &LogicalPlan,
plan: LogicalPlan,
options: &ConfigOptions,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| self.analyze_internal(p, options))
.collect::<Result<Vec<_>>>()?;

) -> Result<Transformed<LogicalPlan>> {
// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());
let mut schema = merge_schema(plan.inputs());

if let LogicalPlan::TableScan(ts) = plan {
if let LogicalPlan::TableScan(ts) = &plan {
let source_schema = DFSchema::try_from_qualified_schema(
ts.table_name.clone(),
&ts.source.schema(),
)?;
schema.merge(&source_schema);
}

let mut expr_rewrite = OperatorToFunctionRewriter {
function_rewrites: &self.function_rewrites,
options,
schema: &schema,
};
let name_preserver = NamePreserver::new(&plan);

plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
// ensure names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;
// recursively transform the expression, applying the rewrites at each step
let result = expr.transform_up(&|expr| {
let mut result = Transformed::no(expr);
for rewriter in self.function_rewrites.iter() {
result = result.transform_data(|expr| {
rewriter.rewrite(expr, &schema, options)
})?;
}
Ok(result)
})?;

plan.with_new_exprs(new_expr, new_inputs)
result.map_data(|expr| original_name.restore(expr))
})
}
}
struct OperatorToFunctionRewriter<'a> {
function_rewrites: &'a [Arc<dyn FunctionRewrite + Send + Sync>],
options: &'a ConfigOptions,
schema: &'a DFSchema,
}

impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> {
type Node = Expr;

fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
// apply transforms one by one
let mut transformed = false;
for rewriter in self.function_rewrites.iter() {
let result = rewriter.rewrite(expr, self.schema, self.options)?;
if result.transformed {
transformed = true;
}
expr = result.data
}
impl AnalyzerRule for ApplyFunctionRewrites {
fn name(&self) -> &str {
"apply_function_rewrites"
}

Ok(if transformed {
Transformed::yes(expr)
} else {
Transformed::no(expr)
})
fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
plan.transform_up_with_subqueries(&|plan| self.rewrite_plan(plan, options))
.map(|res| res.data)
}
}
44 changes: 44 additions & 0 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,47 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
expr_utils::merge_schema(inputs)
}

/// Handles ensuring the name of rewritten expressions is not changed.
///
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/arrow-datafusion/issues/3555> for details
pub struct NamePreserver {
use_alias: bool,
}

/// If the name of an expression is remembered, it will be preserved when
/// rewriting the expression
pub struct SavedName(Option<String>);

impl NamePreserver {
/// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
pub fn new(plan: &LogicalPlan) -> Self {
Self {
use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)),
}
}

pub fn save(&self, expr: &Expr) -> Result<SavedName> {
let original_name = if self.use_alias {
Some(expr.name_for_alias()?)
} else {
None
};

Ok(SavedName(original_name))
}
}

impl SavedName {
/// Ensures the name of the rewritten expression is preserved
pub fn restore(self, expr: Expr) -> Result<Expr> {
let Self(original_name) = self;
match original_name {
Some(name) => expr.alias_if_changed(name),
None => Ok(expr),
}
}
}
55 changes: 55 additions & 0 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1060,3 +1060,58 @@ logical_plan
Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1)
--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a
----TableScan: t projection=[a]

###
## Ensure that operators are rewritten in subqueries
###

statement ok
create table foo(x int) as values (1);

# Show input data
query ?
select struct(1, 'b')
----
{c0: 1, c1: b}


query T
select (select struct(1, 'b')['c1']);
----
b

query T
select 'foo' || (select struct(1, 'b')['c1']);
----
foob

query I
SELECT * FROM (VALUES (1), (2))
WHERE column1 IN (SELECT struct(1, 'b')['c0']);
----
1

# also add an expression so the subquery is the output expr
query I
SELECT * FROM (VALUES (1), (2))
WHERE 1+2 = 3 AND column1 IN (SELECT struct(1, 'b')['c0']);
----
1


query I
SELECT * FROM foo
WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
----
1

# also add an expression so the subquery is the output expr
query I
SELECT * FROM foo
WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
----
1


statement ok
drop table foo;

0 comments on commit e161cd6

Please sign in to comment.