diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 1b62c1bc05c1..6bacc1870079 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -667,8 +667,6 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { string_concat_internal_coercion(from_type, &LargeUtf8) } - // TODO: cast between array elements (#6558) - (List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()), _ => None, }) } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 14d5ddf47378..9d47299a5616 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -17,6 +17,7 @@ pub mod count_wildcard_rule; pub mod inline_table_scan; +pub mod rewrite_expr; pub mod subquery; pub mod type_coercion; @@ -37,6 +38,8 @@ use log::debug; use std::sync::Arc; use std::time::Instant; +use self::rewrite_expr::OperatorToFunction; + /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. /// @@ -72,6 +75,9 @@ impl Analyzer { pub fn new() -> Self { let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![ Arc::new(InlineTableScan::new()), + // OperatorToFunction should be run before TypeCoercion, since it rewrites based on the argument types (List or Scalar), + // and TypeCoercion may cast the argument types from Scalar to List. + Arc::new(OperatorToFunction::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), ]; diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs new file mode 100644 index 000000000000..8f1c844ed062 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzer rule to replace operators with function calls (e.g. `||` to `array_concat`) + +use std::sync::Arc; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::utils::list_ndims; +use datafusion_common::DFSchema; +use datafusion_common::DFSchemaRef; +use datafusion_common::Result; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::utils::merge_schema; +use datafusion_expr::BuiltinScalarFunction; +use datafusion_expr::Operator; +use datafusion_expr::ScalarFunctionDefinition; +use datafusion_expr::{BinaryExpr, Expr, LogicalPlan}; + +use super::AnalyzerRule; + +#[derive(Default)] +pub struct OperatorToFunction {} + +impl OperatorToFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for OperatorToFunction { + fn name(&self) -> &str { + "operator_to_function" + } + + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> { + analyze_internal(&plan) + } +} + +fn analyze_internal(plan: &LogicalPlan) -> Result<LogicalPlan> { + // optimize child plans first + let new_inputs = plan + .inputs() + .iter() + .map(|p| analyze_internal(p)) + .collect::<Result<Vec<_>>>()?; + + // get schema representing all available input 
fields. This is used for data type + // resolution only, so order does not matter here + let mut schema = merge_schema(new_inputs.iter().collect()); + + if let LogicalPlan::TableScan(ts) = plan { + let source_schema = + DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?; + schema.merge(&source_schema); + } + + let mut expr_rewrite = OperatorToFunctionRewriter { + schema: Arc::new(schema), + }; + + let new_expr = plan + .expressions() + .into_iter() + .map(|expr| { + // ensure names don't change: + // https://github.com/apache/arrow-datafusion/issues/3555 + rewrite_preserving_name(expr, &mut expr_rewrite) + }) + .collect::>>()?; + + plan.with_new_exprs(new_expr, &new_inputs) +} + +pub(crate) struct OperatorToFunctionRewriter { + pub(crate) schema: DFSchemaRef, +} + +impl TreeNodeRewriter for OperatorToFunctionRewriter { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) => { + if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( + left.as_ref(), + op, + right.as_ref(), + self.schema.as_ref(), + )? 
+ .or_else(|| { + rewrite_array_concat_operator_to_func( + left.as_ref(), + op, + right.as_ref(), + ) + }) { + // Convert &Box<Expr> -> Expr + let left = (**left).clone(); + let right = (**right).clone(); + return Ok(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args: vec![left, right], + })); + } + + Ok(expr) + } + _ => Ok(expr), + } + } +} + +/// Summary of the logic below: +/// +/// 1) array || array -> array concat +/// +/// 2) array || scalar -> array append +/// +/// 3) scalar || array -> array prepend +/// +/// 4) (array concat, array append, array prepend) || array -> array concat +/// +/// 5) (array concat, array append, array prepend) || scalar -> array append +fn rewrite_array_concat_operator_to_func( + left: &Expr, + op: Operator, + right: &Expr, +) -> Option<BuiltinScalarFunction> { + // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat + + if op != Operator::StringConcat { + return None; + } + + match (left, right) { + // Chain concat operator (a || b) || array, + // (array concat, array append, array prepend) || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + 
args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // Chain concat operator (a || b) || scalar, + // (array concat, array append, array prepend) || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + _scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // array || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // array || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + _right_scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // scalar || array -> array prepend + ( + _left_scalar, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayPrepend), + + _ => None, + } +} + +/// Summary of the logic below: +/// +/// 1) (array concat, array append, array prepend) || column -> (array append, array concat) +/// +/// 2) column1 || column2 -> (array prepend, array append, array concat) +fn rewrite_array_concat_operator_to_func_for_column( + left: &Expr, + op: Operator, + right: &Expr, + 
schema: &DFSchema, +) -> Result> { + if op != Operator::StringConcat { + return Ok(None); + } + + match (left, right) { + // Column cases: + // 1) array_prepend/append/concat || column + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::Column(c), + ) => { + let d = schema.field_from_column(c)?.data_type(); + let ndim = list_ndims(d); + match ndim { + 0 => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + // 2) select column1 || column2 + (Expr::Column(c1), Expr::Column(c2)) => { + let d1 = schema.field_from_column(c1)?.data_type(); + let d2 = schema.field_from_column(c2)?.data_type(); + let ndim1 = list_ndims(d1); + let ndim2 = list_ndims(d2); + match (ndim1, ndim2) { + (0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)), + (_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + _ => Ok(None), + } +} diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c17081398cb8..8c4078dbce8c 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,9 +20,7 @@ mod kernels; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::array_expressions::{ - array_append, array_concat, array_has_all, array_prepend, -}; +use crate::array_expressions::array_has_all; use crate::expressions::datum::{apply, apply_cmp}; use 
crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::physical_expr::down_cast_any_ref; @@ -598,12 +596,7 @@ impl BinaryExpr { BitwiseXor => bitwise_xor_dyn(left, right), BitwiseShiftRight => bitwise_shift_right_dyn(left, right), BitwiseShiftLeft => bitwise_shift_left_dyn(left, right), - StringConcat => match (left_data_type, right_data_type) { - (DataType::List(_), DataType::List(_)) => array_concat(&[left, right]), - (DataType::List(_), _) => array_append(&[left, right]), - (_, DataType::List(_)) => array_prepend(&[left, right]), - _ => binary_string_array_op!(left, right, concat_elements), - }, + StringConcat => binary_string_array_op!(left, right, concat_elements), AtArrow => array_has_all(&[left, right]), ArrowAt => array_has_all(&[right, left]), } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 27351e10eb34..9fded63af3fc 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -98,11 +98,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { StackEntry::Operator(op) => { let right = eval_stack.pop().unwrap(); let left = eval_stack.pop().unwrap(); + let expr = Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, Box::new(right), )); + eval_stack.push(expr); } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 6dab3b3084a9..2726af5d4c97 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -4191,6 +4191,45 @@ select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_a ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +# array concatenate operator with scalars #4 (mixed) +query ? +select 0 || [1,2,3] || 4 || [5] || [6,7]; +---- +[0, 1, 2, 3, 4, 5, 6, 7] + +# array concatenate operator with nd-list #5 (mixed) +query ? 
+select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10]; +---- +[[0, 1, 2, 3], [4, 5], [6, 7, 8], [9, 10]] + +# array concatenate operator non-valid cases +## concat 2D with scalar is not valid +query error +select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10] || 11; + +## concat scalar with 2D is not valid +query error +select 0 || [[1,2,3]]; + +# array concatenate operator with column + +statement ok +CREATE TABLE array_concat_operator_table +AS VALUES + (0, [1, 2, 2, 3], 4, [5, 6, 5]), + (-1, [4, 5, 6], 7, [8, 1, 1]) +; + +query ? +select column1 || column2 || column3 || column4 from array_concat_operator_table; +---- +[0, 1, 2, 2, 3, 4, 5, 6, 5] +[-1, 4, 5, 6, 7, 8, 1, 1] + +statement ok +drop table array_concat_operator_table; + ## array containment operator # array containment operator with scalars #1 (at arrow) diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 4583ef319b7f..2a39e3138869 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -180,6 +180,7 @@ initial_logical_plan Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c --TableScan: simple_explain_test logical_plan after inline_table_scan SAME TEXT AS ABOVE +logical_plan after operator_to_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE