From bffd82552e3820e215b48694b3ffa6dda6b74d09 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 23 Dec 2023 10:54:10 +0800 Subject: [PATCH 01/11] reuse function for string concat Signed-off-by: jayzhan211 --- .../physical-expr/src/expressions/binary.rs | 11 +- datafusion/sql/src/expr/mod.rs | 184 +++++++++++++++++- 2 files changed, 181 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c17081398cb8..8c4078dbce8c 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,9 +20,7 @@ mod kernels; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::array_expressions::{ - array_append, array_concat, array_has_all, array_prepend, -}; +use crate::array_expressions::array_has_all; use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::physical_expr::down_cast_any_ref; @@ -598,12 +596,7 @@ impl BinaryExpr { BitwiseXor => bitwise_xor_dyn(left, right), BitwiseShiftRight => bitwise_shift_right_dyn(left, right), BitwiseShiftLeft => bitwise_shift_left_dyn(left, right), - StringConcat => match (left_data_type, right_data_type) { - (DataType::List(_), DataType::List(_)) => array_concat(&[left, right]), - (DataType::List(_), _) => array_append(&[left, right]), - (_, DataType::List(_)) => array_prepend(&[left, right]), - _ => binary_string_array_op!(left, right, concat_elements), - }, + StringConcat => binary_string_array_op!(left, right, concat_elements), AtArrow => array_has_all(&[left, right]), ArrowAt => array_has_all(&[right, left]), } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 27351e10eb34..6b63537d5874 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -37,6 +37,7 @@ use datafusion_common::{ use 
datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::ScalarFunctionDefinition; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast, @@ -98,11 +99,184 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { StackEntry::Operator(op) => { let right = eval_stack.pop().unwrap(); let left = eval_stack.pop().unwrap(); - let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(left), - op, - Box::new(right), - )); + + // Summary of the logic below: + // array || array -> array concat + // array || scalar -> array append + // scalar || array -> array prepend + // (arry concat, array append, array prepend) || array -> array concat + // (arry concat, array append, array prepend) || scalar -> array append + + // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat + let expr = match (op, &left, &right) { + // Chain concat operator (a || b) || array, + // (arry concat, array append, array prepend) || array -> array concat + ( + Operator::StringConcat, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ArrayConcat, + ), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::MakeArray, + ), + args: _right_args, + }), + ) + | ( + Operator::StringConcat, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ArrayAppend, + ), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::MakeArray, + ), + args: _right_args, + }), + ) + | ( + Operator::StringConcat, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + 
BuiltinScalarFunction::ArrayPrepend, + ), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::MakeArray, + ), + args: _right_args, + }), + ) => { + let args = vec![left.clone(), right.clone()]; + Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::ArrayConcat, + args, + )) + } + // Chain concat operator (a || b) || scalar, + // (arry concat, array append, array prepend) || scalar -> array append + ( + Operator::StringConcat, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ArrayConcat, + ), + args: _left_args, + }), + _, + ) + | ( + Operator::StringConcat, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ArrayAppend, + ), + args: _left_args, + }), + _, + ) + | ( + Operator::StringConcat, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ArrayPrepend, + ), + args: _left_args, + }), + _, + ) => { + let args = vec![left.clone(), right.clone()]; + Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::ArrayAppend, + args, + )) + } + // array || array -> array concat + ( + Operator::StringConcat, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::MakeArray, + ), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::MakeArray, + ), + args: _right_args, + }), + ) => { + let args = vec![left.clone(), right.clone()]; + Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::ArrayConcat, + args, + )) + } + // array || scalar -> array append + ( + Operator::StringConcat, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::MakeArray, + ), + args: 
_left_args, + }), + _, + ) => { + let args = vec![left.clone(), right.clone()]; + Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::ArrayAppend, + args, + )) + } + // scalar || array -> array prepend + ( + Operator::StringConcat, + _, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::MakeArray, + ), + args: _right_args, + }), + ) => { + let args = vec![left.clone(), right.clone()]; + Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::ArrayPrepend, + args, + )) + } + // Any other cases fall back to the default Binary Operation + _ => Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + op, + Box::new(right), + )), + }; + eval_stack.push(expr); } } From e7cfcce64825b8d001c915652486f2ae0c53df85 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 23 Dec 2023 11:39:31 +0800 Subject: [PATCH 02/11] remove casting in string concat Signed-off-by: jayzhan211 --- datafusion/expr/src/type_coercion/binary.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 1b62c1bc05c1..6bacc1870079 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -667,8 +667,6 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_concat_internal_coercion(from_type, &LargeUtf8) } - // TODO: cast between array elements (#6558) - (List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()), _ => None, }) } From 0c18f78329b774ab4ab540076035735ee5a223c3 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Thu, 28 Dec 2023 09:07:03 +0800 Subject: [PATCH 03/11] add test Signed-off-by: jayzhan211 --- datafusion/sqllogictest/test_files/array.slt | 22 ++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 
6dab3b3084a9..a21867ecdd01 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -4191,6 +4191,28 @@ select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_a ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +# array concatenate operator with scalars #4 (mixed) +query ? +select 0 || [1,2,3] || 4 || [5] || [6,7]; +---- +[0, 1, 2, 3, 4, 5, 6, 7] + +# array concatenate operator with nd-list #5 (mixed) +query ? +select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10]; +---- +[[0, 1, 2, 3], [4, 5], [6, 7, 8], [9, 10]] + +# array concatenate operator non-valid cases +## concat 2D with scalar is not valid +query error +select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10] || 11; + +## concat scalar with 2D is not valid +query error +select 0 || [[1,2,3]]; + + ## array containment operator # array containment operator with scalars #1 (at arrow) From 6201176062369ec8f6263bf816b76a3d8c827a87 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 2 Jan 2024 20:48:58 +0800 Subject: [PATCH 04/11] operator to function rewrite Signed-off-by: jayzhan211 --- datafusion/optimizer/src/analyzer/mod.rs | 6 + .../optimizer/src/analyzer/rewrite_expr.rs | 232 ++++++++++++++++++ datafusion/sql/src/expr/mod.rs | 182 +------------- 3 files changed, 243 insertions(+), 177 deletions(-) create mode 100644 datafusion/optimizer/src/analyzer/rewrite_expr.rs diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 14d5ddf47378..9d47299a5616 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -17,6 +17,7 @@ pub mod count_wildcard_rule; pub mod inline_table_scan; +pub mod rewrite_expr; pub mod subquery; pub mod type_coercion; @@ -37,6 +38,8 @@ use log::debug; use std::sync::Arc; use std::time::Instant; +use self::rewrite_expr::OperatorToFunction; + /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// 
the plan valid prior to the rest of the DataFusion optimization process. /// @@ -72,6 +75,9 @@ impl Analyzer { pub fn new() -> Self { let rules: Vec> = vec![ Arc::new(InlineTableScan::new()), + // OperatorToFunction should be run before TypeCoercion, since it rewrite based on the argument types (List or Scalar), + // and TypeCoercion may cast the argument types from Scalar to List. + Arc::new(OperatorToFunction::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), ]; diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs new file mode 100644 index 000000000000..b7799fbd2920 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Optimizer rule for expression rewrite + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::Result; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::BuiltinScalarFunction; +use datafusion_expr::Operator; +use datafusion_expr::Projection; +use datafusion_expr::ScalarFunctionDefinition; +use datafusion_expr::{BinaryExpr, Expr, LogicalPlan}; + +use super::AnalyzerRule; + +#[derive(Default)] +pub struct OperatorToFunction {} + +impl OperatorToFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for OperatorToFunction { + fn name(&self) -> &str { + "operator_to_function" + } + + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + analyze_internal(plan) + } +} + +fn analyze_internal(plan: LogicalPlan) -> Result { + // OperatorToFunction is only applied to Projection + match plan { + LogicalPlan::Projection(_) => {} + _ => { + return Ok(plan); + } + } + + let mut expr_rewriter = OperatorToFunctionRewriter {}; + + let new_expr = plan + .expressions() + .into_iter() + .map(|expr| expr.rewrite(&mut expr_rewriter)) + .collect::>>()?; + + // Not found cases that inputs more than one + assert_eq!(plan.inputs().len(), 1); + let input = plan.inputs()[0]; + + Ok(LogicalPlan::Projection(Projection::try_new( + new_expr, + input.to_owned().into(), + )?)) +} + +pub(crate) struct OperatorToFunctionRewriter {} + +impl TreeNodeRewriter for OperatorToFunctionRewriter { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) => { + if let Some(fun) = rewrite_array_concat_operator_to_func( + left.as_ref(), + op, + right.as_ref(), + ) { + // Convert &Box -> Expr + let left = (**left).clone(); + let right = (**right).clone(); + return Ok(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + 
args: vec![left, right], + })); + } + Ok(expr) + } + _ => Ok(expr), + } + } +} + +/// Summary of the logic below: +/// +/// array || array -> array concat +/// +/// array || scalar -> array append +/// +/// scalar || array -> array prepend +/// +/// (arry concat, array append, array prepend) || array -> array concat +/// +/// (arry concat, array append, array prepend) || scalar -> array append +fn rewrite_array_concat_operator_to_func( + left: &Expr, + op: Operator, + right: &Expr, +) -> Option { + // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat + + if op != Operator::StringConcat { + return None; + } + + match (left, right) { + // Chain concat operator (a || b) || array, + // (arry concat, array append, array prepend) || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // Chain concat operator (a || b) || scalar, + // (arry concat, array append, array prepend) || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + 
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + _scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // array || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // array || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + _right_scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // scalar || array -> array prepend + ( + _left_scalar, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayPrepend), + + _ => None, + } +} diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 6b63537d5874..9fded63af3fc 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -37,7 +37,6 @@ use datafusion_common::{ use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::ScalarFunctionDefinition; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, Expr, ExprSchemable, GetFieldAccess, 
GetIndexedField, Like, Operator, TryCast, @@ -100,182 +99,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let right = eval_stack.pop().unwrap(); let left = eval_stack.pop().unwrap(); - // Summary of the logic below: - // array || array -> array concat - // array || scalar -> array append - // scalar || array -> array prepend - // (arry concat, array append, array prepend) || array -> array concat - // (arry concat, array append, array prepend) || scalar -> array append - - // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat - let expr = match (op, &left, &right) { - // Chain concat operator (a || b) || array, - // (arry concat, array append, array prepend) || array -> array concat - ( - Operator::StringConcat, - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::ArrayConcat, - ), - args: _left_args, - }), - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::MakeArray, - ), - args: _right_args, - }), - ) - | ( - Operator::StringConcat, - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::ArrayAppend, - ), - args: _left_args, - }), - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::MakeArray, - ), - args: _right_args, - }), - ) - | ( - Operator::StringConcat, - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::ArrayPrepend, - ), - args: _left_args, - }), - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::MakeArray, - ), - args: _right_args, - }), - ) => { - let args = vec![left.clone(), right.clone()]; - Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::ArrayConcat, - args, - )) - } - // Chain concat operator (a || b) || scalar, - // (arry concat, array append, 
array prepend) || scalar -> array append - ( - Operator::StringConcat, - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::ArrayConcat, - ), - args: _left_args, - }), - _, - ) - | ( - Operator::StringConcat, - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::ArrayAppend, - ), - args: _left_args, - }), - _, - ) - | ( - Operator::StringConcat, - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::ArrayPrepend, - ), - args: _left_args, - }), - _, - ) => { - let args = vec![left.clone(), right.clone()]; - Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::ArrayAppend, - args, - )) - } - // array || array -> array concat - ( - Operator::StringConcat, - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::MakeArray, - ), - args: _left_args, - }), - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::MakeArray, - ), - args: _right_args, - }), - ) => { - let args = vec![left.clone(), right.clone()]; - Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::ArrayConcat, - args, - )) - } - // array || scalar -> array append - ( - Operator::StringConcat, - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::MakeArray, - ), - args: _left_args, - }), - _, - ) => { - let args = vec![left.clone(), right.clone()]; - Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::ArrayAppend, - args, - )) - } - // scalar || array -> array prepend - ( - Operator::StringConcat, - _, - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::MakeArray, - ), - args: _right_args, - }), - ) => { - let args = vec![left.clone(), right.clone()]; - 
Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::ArrayPrepend, - args, - )) - } - // Any other cases fall back to the default Binary Operation - _ => Expr::BinaryExpr(BinaryExpr::new( - Box::new(left), - op, - Box::new(right), - )), - }; + let expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + op, + Box::new(right), + )); eval_stack.push(expr); } From 69df2b6dfa01d756cd267c58b3025fa6c04b3c67 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 2 Jan 2024 21:02:54 +0800 Subject: [PATCH 05/11] fix explain Signed-off-by: jayzhan211 --- datafusion/sqllogictest/test_files/explain.slt | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 4583ef319b7f..2a39e3138869 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -180,6 +180,7 @@ initial_logical_plan Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c --TableScan: simple_explain_test logical_plan after inline_table_scan SAME TEXT AS ABOVE +logical_plan after operator_to_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE From 77386b8542b6d29d5614b47efa5ed5b0bd65bbb6 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 3 Jan 2024 07:34:27 +0800 Subject: [PATCH 06/11] add more test Signed-off-by: jayzhan211 --- datafusion/sqllogictest/test_files/array.slt | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index a21867ecdd01..2726af5d4c97 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -4212,6 +4212,23 @@ select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10] || 11; query error select 0 || [[1,2,3]]; +# array 
concatenate operator with column + +statement ok +CREATE TABLE array_concat_operator_table +AS VALUES + (0, [1, 2, 2, 3], 4, [5, 6, 5]), + (-1, [4, 5, 6], 7, [8, 1, 1]) +; + +query ? +select column1 || column2 || column3 || column4 from array_concat_operator_table; +---- +[0, 1, 2, 2, 3, 4, 5, 6, 5] +[-1, 4, 5, 6, 7, 8, 1, 1] + +statement ok +drop table array_concat_operator_table; ## array containment operator From a90539488e13879230d1ea9dcf393ae4c9e47522 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 3 Jan 2024 09:30:23 +0800 Subject: [PATCH 07/11] add column cases Signed-off-by: jayzhan211 --- .../optimizer/src/analyzer/rewrite_expr.rs | 153 +++++++++++++++--- 1 file changed, 128 insertions(+), 25 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index b7799fbd2920..e64ed66f2306 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -17,14 +17,21 @@ //! 
Optimizer rule for expression rewrite +use std::sync::Arc; + +use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::TreeNode; use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::utils::list_ndims; +use datafusion_common::Column; +use datafusion_common::DFSchema; +use datafusion_common::DFSchemaRef; use datafusion_common::Result; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::utils::merge_schema; use datafusion_expr::BuiltinScalarFunction; use datafusion_expr::Operator; -use datafusion_expr::Projection; use datafusion_expr::ScalarFunctionDefinition; use datafusion_expr::{BinaryExpr, Expr, LogicalPlan}; @@ -45,20 +52,31 @@ impl AnalyzerRule for OperatorToFunction { } fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - analyze_internal(plan) + analyze_internal(&plan) } } -fn analyze_internal(plan: LogicalPlan) -> Result { - // OperatorToFunction is only applied to Projection - match plan { - LogicalPlan::Projection(_) => {} - _ => { - return Ok(plan); - } +fn analyze_internal(plan: &LogicalPlan) -> Result { + // optimize child plans first + let new_inputs = plan + .inputs() + .iter() + .map(|p| analyze_internal(p)) + .collect::>>()?; + + // get schema representing all available input fields. 
This is used for data type + // resolution only, so order does not matter here + let mut schema = merge_schema(new_inputs.iter().collect()); + + if let LogicalPlan::TableScan(ts) = plan { + let source_schema = + DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?; + schema.merge(&source_schema); } - let mut expr_rewriter = OperatorToFunctionRewriter {}; + let mut expr_rewriter = OperatorToFunctionRewriter { + schema: Arc::new(schema), + }; let new_expr = plan .expressions() @@ -66,17 +84,28 @@ fn analyze_internal(plan: LogicalPlan) -> Result { .map(|expr| expr.rewrite(&mut expr_rewriter)) .collect::>>()?; - // Not found cases that inputs more than one - assert_eq!(plan.inputs().len(), 1); - let input = plan.inputs()[0]; + plan.with_new_exprs(new_expr, &new_inputs) + + // match &plan { + // LogicalPlan::Projection(_) => Ok(LogicalPlan::Projection(Projection::try_new( + // new_expr, + // Arc::new(new_inputs[0].clone()), + // )?)), + // _ => plan.with_new_exprs(new_expr, &new_inputs), + // } + // // Not found cases that inputs more than one + // assert_eq!(plan.inputs().len(), 1); + // let input = plan.inputs()[0]; - Ok(LogicalPlan::Projection(Projection::try_new( - new_expr, - input.to_owned().into(), - )?)) + // Ok(LogicalPlan::Projection(Projection::try_new( + // new_expr, + // input.to_owned().into(), + // )?)) } -pub(crate) struct OperatorToFunctionRewriter {} +pub(crate) struct OperatorToFunctionRewriter { + pub(crate) schema: DFSchemaRef, +} impl TreeNodeRewriter for OperatorToFunctionRewriter { type N = Expr; @@ -88,11 +117,19 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter { op, ref right, }) => { - if let Some(fun) = rewrite_array_concat_operator_to_func( + if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( left.as_ref(), op, right.as_ref(), - ) { + self.schema.as_ref(), + )? 
+ .or_else(|| { + rewrite_array_concat_operator_to_func( + left.as_ref(), + op, + right.as_ref(), + ) + }) { // Convert &Box -> Expr let left = (**left).clone(); let right = (**right).clone(); @@ -101,6 +138,7 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter { args: vec![left, right], })); } + Ok(expr) } _ => Ok(expr), @@ -110,15 +148,15 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter { /// Summary of the logic below: /// -/// array || array -> array concat +/// 1) array || array -> array concat /// -/// array || scalar -> array append +/// 2) array || scalar -> array append /// -/// scalar || array -> array prepend +/// 3) scalar || array -> array prepend /// -/// (arry concat, array append, array prepend) || array -> array concat +/// 4) (arry concat, array append, array prepend) || array -> array concat /// -/// (arry concat, array append, array prepend) || scalar -> array append +/// 5) (arry concat, array append, array prepend) || scalar -> array append fn rewrite_array_concat_operator_to_func( left: &Expr, op: Operator, @@ -230,3 +268,68 @@ fn rewrite_array_concat_operator_to_func( _ => None, } } + +/// Summary of the logic below: +/// +/// 1) (arry concat, array append, array prepend) || column -> (array append, array concat) +/// +/// 2) column1 || column2 -> (array prepend, array append, array concat) +fn rewrite_array_concat_operator_to_func_for_column( + left: &Expr, + op: Operator, + right: &Expr, + schema: &DFSchema, +) -> Result> { + if op != Operator::StringConcat { + return Ok(None); + } + + match (left, right) { + // Column cases: + // 1) array_prepend/append/concat || column + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( 
+ Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::Column(c), + ) => { + let d = schema.field_from_column(c)?.data_type(); + let ndim = list_ndims(d); + match ndim { + 0 => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + // 2) select column1 || column2 + (Expr::Column(c1), Expr::Column(c2)) => { + let d0 = schema.field_from_column(c1)?.data_type(); + let d1 = schema.field_from_column(c2)?.data_type(); + let ndim0 = list_ndims(d0); + let ndim1 = list_ndims(d1); + match (ndim0, ndim1) { + (0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)), + (_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + _ => Ok(None), + } +} From e4c83ab82a64f60cb12beb2c24d89a4c5b2243fd Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 3 Jan 2024 20:03:47 +0800 Subject: [PATCH 08/11] cleanup Signed-off-by: jayzhan211 --- .../optimizer/src/analyzer/rewrite_expr.rs | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index e64ed66f2306..6360a27322a6 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -19,12 +19,10 @@ use std::sync::Arc; -use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::TreeNode; use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::utils::list_ndims; -use datafusion_common::Column; use datafusion_common::DFSchema; use datafusion_common::DFSchemaRef; use datafusion_common::Result; @@ -77,7 +75,7 @@ fn analyze_internal(plan: &LogicalPlan) -> Result { let mut expr_rewriter = OperatorToFunctionRewriter { schema: Arc::new(schema), }; - + let new_expr = plan 
.expressions() .into_iter() @@ -85,22 +83,6 @@ fn analyze_internal(plan: &LogicalPlan) -> Result { .collect::>>()?; plan.with_new_exprs(new_expr, &new_inputs) - - // match &plan { - // LogicalPlan::Projection(_) => Ok(LogicalPlan::Projection(Projection::try_new( - // new_expr, - // Arc::new(new_inputs[0].clone()), - // )?)), - // _ => plan.with_new_exprs(new_expr, &new_inputs), - // } - // // Not found cases that inputs more than one - // assert_eq!(plan.inputs().len(), 1); - // let input = plan.inputs()[0]; - - // Ok(LogicalPlan::Projection(Projection::try_new( - // new_expr, - // input.to_owned().into(), - // )?)) } pub(crate) struct OperatorToFunctionRewriter { From 1ef0e5018abe91f70b321c69822ec8cd27f1890b Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 3 Jan 2024 20:23:52 +0800 Subject: [PATCH 09/11] presever name Signed-off-by: jayzhan211 --- datafusion/optimizer/src/analyzer/rewrite_expr.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index 6360a27322a6..1ed0da24c43e 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -20,13 +20,13 @@ use std::sync::Arc; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNode; use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::utils::list_ndims; use datafusion_common::DFSchema; use datafusion_common::DFSchemaRef; use datafusion_common::Result; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::merge_schema; use datafusion_expr::BuiltinScalarFunction; use datafusion_expr::Operator; @@ -72,14 +72,18 @@ fn analyze_internal(plan: &LogicalPlan) -> Result { schema.merge(&source_schema); } - let mut expr_rewriter = OperatorToFunctionRewriter { + let mut expr_rewrite = 
OperatorToFunctionRewriter { schema: Arc::new(schema), }; - + let new_expr = plan .expressions() .into_iter() - .map(|expr| expr.rewrite(&mut expr_rewriter)) + .map(|expr| { + // ensure names don't change: + // https://github.com/apache/arrow-datafusion/issues/3555 + rewrite_preserving_name(expr, &mut expr_rewrite) + }) .collect::>>()?; plan.with_new_exprs(new_expr, &new_inputs) From 111299bad6ea6514f13b4e3831c4999374e3c60f Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 3 Jan 2024 20:28:03 +0800 Subject: [PATCH 10/11] Update datafusion/optimizer/src/analyzer/rewrite_expr.rs Co-authored-by: Andrew Lamb --- datafusion/optimizer/src/analyzer/rewrite_expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index 1ed0da24c43e..afbad1828f64 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule for expression rewrite +//! 
Analyzer rule to replace operators with function calls (e.g. `||` to `array_concat`) use std::sync::Arc; From 6a0d78212b092715f3d5749d3973035078e17fed Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 3 Jan 2024 20:29:56 +0800 Subject: [PATCH 11/11] rename Signed-off-by: jayzhan211 --- datafusion/optimizer/src/analyzer/rewrite_expr.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index afbad1828f64..8f1c844ed062 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -306,11 +306,11 @@ fn rewrite_array_concat_operator_to_func_for_column( } // 2) select column1 || column2 (Expr::Column(c1), Expr::Column(c2)) => { - let d0 = schema.field_from_column(c1)?.data_type(); - let d1 = schema.field_from_column(c2)?.data_type(); - let ndim0 = list_ndims(d0); + let d1 = schema.field_from_column(c1)?.data_type(); + let d2 = schema.field_from_column(c2)?.data_type(); let ndim1 = list_ndims(d1); - match (ndim0, ndim1) { + let ndim2 = list_ndims(d2); + match (ndim1, ndim2) { (0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)), (_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)), _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)),