From c8dc3a4269144d27752f6d6b71dee5ec0b5cf85a Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 20 Oct 2023 11:12:58 +0300 Subject: [PATCH 01/15] Initial commit --- datafusion/common/src/dfschema.rs | 8 +++ datafusion/core/tests/sql/group_by.rs | 14 ++-- datafusion/expr/src/built_in_function.rs | 65 +++++++++++++++++-- datafusion/expr/src/logical_plan/plan.rs | 44 ++++++++++++- datafusion/expr/src/type_coercion/binary.rs | 2 +- .../optimizer/src/common_subexpr_eliminate.rs | 14 +--- datafusion/optimizer/src/merge_projection.rs | 3 +- .../optimizer/src/push_down_projection.rs | 12 +--- .../src/single_distinct_to_groupby.rs | 52 +++++++++------ .../physical-expr/src/array_expressions.rs | 13 ++-- datafusion/sqllogictest/test_files/array.slt | 6 +- .../sqllogictest/test_files/tpch/q16.slt.part | 20 +++--- 12 files changed, 179 insertions(+), 74 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index b1aee41978c2..e16acbfedc81 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -444,6 +444,14 @@ impl DFSchema { .zip(iter2) .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_semantically_equal(f1, f2)) } + ( + DataType::Decimal128(_l_precision, _l_scale), + DataType::Decimal128(_r_precision, _r_scale), + ) => true, + ( + DataType::Decimal256(_l_precision, _l_scale), + DataType::Decimal256(_r_precision, _r_scale), + ) => true, _ => dt1 == dt2, } } diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index 7c7703b69683..58f0ac21d951 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -231,13 +231,13 @@ async fn group_by_dictionary() { .expect("ran plan correctly"); let expected = [ - "+-------+------------------------+", - "| t.val | COUNT(DISTINCT t.dict) |", - "+-------+------------------------+", - "| 1 | 2 |", - "| 2 | 2 |", - "| 4 | 1 |", - "+-------+------------------------+", + "+-----+------------------------+", + "| val | COUNT(DISTINCT t.dict) |", + "+-----+------------------------+", + "| 1 | 2 |", + "| 2 | 2 |", + "| 4 | 1 |", + "+-----+------------------------+", ]; assert_batches_sorted_eq!(expected, &results); } diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 350067a42186..fbe69deddf1c 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -28,6 +28,7 @@ use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, }; +use std::cmp::Ordering; use std::collections::HashMap; use std::fmt; use std::str::FromStr; @@ -315,6 +316,56 @@ fn function_to_name() -> &'static HashMap { }) } +// TODO: Enrich this implementation +fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { + Ok(match (lhs, rhs) { + (DataType::Null, _) => rhs.clone(), + (_, DataType::Null) => lhs.clone(), + // Int + ( + DataType::Int8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => rhs.clone(), + (DataType::Int16, DataType::Int16 | DataType::Int32 | DataType::Int64) => { + rhs.clone() + } + (DataType::Int32, DataType::Int32 | DataType::Int64) => rhs.clone(), + (DataType::Int16 | DataType::Int32 | DataType::Int64, DataType::Int8) => { + lhs.clone() + } + (DataType::Int32 | DataType::Int64, DataType::Int16) => lhs.clone(), + (DataType::Int64, DataType::Int32) => lhs.clone(), + (DataType::Int64, DataType::Int64) => DataType::Int64, + // Float + ( + DataType::Float16, + DataType::Float16 | DataType::Float32 | DataType::Float64, + ) => rhs.clone(), + (DataType::Float32, DataType::Float32 | DataType::Float64) => rhs.clone(), + (DataType::Float32 | DataType::Float64, DataType::Float16) => lhs.clone(), + (DataType::Float64, DataType::Float32) => lhs.clone(), + (DataType::Float64, DataType::Float64) => DataType::Float64, + // String + (DataType::Utf8, DataType::Utf8 | DataType::LargeUtf8) => rhs.clone(), + (DataType::LargeUtf8, DataType::Utf8) => lhs.clone(), + (DataType::LargeUtf8, DataType::LargeUtf8) => DataType::LargeUtf8, + (DataType::List(lhs_field), DataType::List(rhs_field)) => { + let field_type = + get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; + assert_eq!(lhs_field.name(), rhs_field.name()); + let field_name = lhs_field.name(); + let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); + DataType::List(Arc::new(Field::new(field_name, field_type, nullable))) + } + (_, _) => { + return Err(DataFusionError::Execution(format!( + "Cannot concat types lhs: {:?}, rhs:{:?}", + lhs, rhs + ))) + } + }) +} + impl BuiltinScalarFunction { /// an allowlist of functions to take zero arguments, so that they will get special treatment /// while executing. @@ -539,10 +590,16 @@ impl BuiltinScalarFunction { List(field) => { if !field.data_type().equals_datatype(&Null) { let dims = self.return_dimension(input_expr_type.clone()); - if max_dims < dims { - max_dims = dims; - expr_type = input_expr_type.clone(); - } + expr_type = match max_dims.cmp(&dims) { + Ordering::Greater => expr_type, + Ordering::Equal => { + get_wider_type(&expr_type, input_expr_type)? + } + Ordering::Less => { + max_dims = dims; + input_expr_type.clone() + } + }; } } _ => { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1c526c7b4030..186315c698b5 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -45,7 +45,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, - OwnedTableReference, Result, ScalarValue, UnnestOptions, + OwnedTableReference, Result, ScalarValue, ToDFSchema, UnnestOptions, }; // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; @@ -1953,6 +1953,48 @@ impl Hash for TableScan { } } +impl TableScan { + /// Create projection schema using entries at the projection indices + pub fn project_schema(&self, projection: &[usize]) -> Result> { + let schema = self.source.schema(); + let projected_fields: Vec = projection + .iter() + .map(|i| { + DFField::from_qualified( + self.table_name.clone(), + schema.fields()[*i].clone(), + ) + }) + .collect(); + + // Find indices among previous schema + let old_indices = self + .projection + .clone() + .unwrap_or((0..self.projected_schema.fields().len()).collect()); + let new_proj = projection + .iter() + .map(|idx| { + old_indices + .iter() + .position(|old_idx| old_idx == idx) + // TODO: Remove this unwrap + .unwrap() + }) + .collect::>(); + let func_dependencies = self.projected_schema.functional_dependencies(); + let new_func_dependencies = func_dependencies + .project_functional_dependencies(&new_proj, projection.len()); + + let projected_schema = Arc::new( + projected_fields + .to_dfschema()? + .with_functional_dependencies(new_func_dependencies), + ); + Ok(projected_schema) + } +} + /// Apply Cross Join to two logical plans #[derive(Clone, PartialEq, Eq, Hash)] pub struct CrossJoin { diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index a854373e880d..541f5e9446e6 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -120,7 +120,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result } Operator::AtArrow | Operator::ArrowAt => { - array_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { + array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common array type for arrow operation {lhs} {op} {rhs}" ) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index c6b138f8ca36..17e75ccf41dd 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -111,12 +111,7 @@ impl CommonSubexprEliminate { projection: &Projection, config: &dyn OptimizerConfig, ) -> Result { - let Projection { - expr, - input, - schema, - .. - } = projection; + let Projection { expr, input, .. } = projection; let input_schema = Arc::clone(input.schema()); let mut expr_set = ExprSet::new(); let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?; @@ -124,10 +119,9 @@ impl CommonSubexprEliminate { let (mut new_expr, new_input) = self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?; - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( + Ok(LogicalPlan::Projection(Projection::try_new( pop_expr(&mut new_expr)?, Arc::new(new_input), - schema.clone(), )?)) } @@ -201,7 +195,6 @@ impl CommonSubexprEliminate { group_expr, aggr_expr, input, - schema, .. } = aggregate; let mut expr_set = ExprSet::new(); @@ -247,11 +240,10 @@ impl CommonSubexprEliminate { let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { - Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + Ok(LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(new_input), new_group_expr, new_aggr_expr, - schema.clone(), )?)) } else { let mut agg_exprs = vec![]; diff --git a/datafusion/optimizer/src/merge_projection.rs b/datafusion/optimizer/src/merge_projection.rs index 408055b8e7d4..5f208924d03b 100644 --- a/datafusion/optimizer/src/merge_projection.rs +++ b/datafusion/optimizer/src/merge_projection.rs @@ -84,10 +84,9 @@ pub(super) fn merge_projection( Err(e) => Err(e), }) .collect::>>()?; - let new_plan = LogicalPlan::Projection(Projection::try_new_with_schema( + let new_plan = LogicalPlan::Projection(Projection::try_new( new_exprs, child_projection.input.clone(), - parent_projection.schema.clone(), )?); Ok(new_plan) } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 839f6b5bb8f6..06c9ad9413c6 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -27,7 +27,7 @@ use arrow::datatypes::DataType; use arrow::error::Result as ArrowResult; use datafusion_common::ScalarValue::UInt8; use datafusion_common::{ - plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ToDFSchema, + plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::expr::{AggregateFunction, Alias}; use datafusion_expr::utils::exprlist_to_fields; @@ -598,15 +598,7 @@ fn push_down_scan( projection.into_iter().collect::>() }; - // create the projected schema - let projected_fields: Vec = projection - .iter() - .map(|i| { - DFField::from_qualified(scan.table_name.clone(), schema.fields()[*i].clone()) - }) - .collect(); - - let projected_schema = projected_fields.to_dfschema_ref()?; + let projected_schema = scan.project_schema(&projection)?; Ok(LogicalPlan::TableScan(TableScan { table_name: scan.table_name.clone(), diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index a9e65b3e7c77..d608cedee8eb 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -108,10 +108,19 @@ impl OptimizerRule for SingleDistinctToGroupBy { .iter() .enumerate() .map(|(i, group_expr)| { - let alias_str = format!("group_alias_{i}"); - let alias_expr = group_expr.clone().alias(&alias_str); - group_expr_alias - .push((alias_str, schema.fields()[i].clone())); + let (alias_expr, out_group_expr, original_name) = + if let Expr::Column(_) = group_expr { + (group_expr.clone(), group_expr.clone(), None) + } else { + let alias_str = format!("group_alias_{i}"); + let alias_expr = group_expr.clone().alias(&alias_str); + ( + alias_expr, + col(alias_str), + Some(schema.fields()[i].qualified_name()), + ) + }; + group_expr_alias.push((out_group_expr, original_name)); alias_expr }) .collect::>(); @@ -119,7 +128,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // and they can be referenced by the alias in the outer aggr plan let outer_group_exprs = group_expr_alias .iter() - .map(|(alias, _)| col(alias)) + .map(|(out_group_expr, _)| out_group_expr.clone()) .collect::>(); // replace the distinct arg with alias @@ -182,9 +191,13 @@ impl OptimizerRule for SingleDistinctToGroupBy { // - group_by aggr // - aggr expr let mut alias_expr: Vec = Vec::new(); - for (alias, original_field) in group_expr_alias { - alias_expr - .push(col(alias).alias(original_field.qualified_name())); + for (group_expr, original_field) in group_expr_alias { + let expr = if let Some(name) = original_field { + group_expr.alias(name) + } else { + group_expr + }; + alias_expr.push(expr); } for (i, expr) in new_aggr_exprs.iter().enumerate() { alias_expr.push(columnize_expr( @@ -202,13 +215,10 @@ impl OptimizerRule for SingleDistinctToGroupBy { new_aggr_exprs, )?); - Ok(Some(LogicalPlan::Projection( - Projection::try_new_with_schema( - alias_expr, - Arc::new(outer_aggr), - schema.clone(), - )?, - ))) + Ok(Some(LogicalPlan::Projection(Projection::try_new( + alias_expr, + Arc::new(outer_aggr), + )?))) } else { Ok(None) } @@ -362,9 +372,9 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ + let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -408,9 +418,9 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1), MAX(alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ + let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1), MAX(alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 067a4cfdffc0..9fd5a7d07ec5 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -423,9 +423,13 @@ fn array(values: &[ColumnarValue]) -> Result { } // all nulls, set default data type as int32 Some(DataType::Null) => { - let null_arr = vec![ScalarValue::Int32(None); arrays.len()]; - let list_arr = ScalarValue::new_list(null_arr.as_slice(), &DataType::Int32); - Ok(Arc::new(list_arr)) + let nulls = arrays.len(); + let null_arr = NullArray::new(nulls); + let field = Arc::new(Field::new("item", DataType::Null, true)); + let offsets = OffsetBuffer::from_lengths([nulls]); + let values = Arc::new(null_arr) as ArrayRef; + let nulls = None; + Ok(Arc::new(ListArray::new(field, offsets, values, nulls))) } Some(data_type) => Ok(array_array(arrays.as_slice(), data_type)?), } @@ -988,7 +992,8 @@ macro_rules! general_repeat_list { /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { if args[0].as_any().downcast_ref::().is_some() { - return Ok(args[0].clone()); + // Make sure to return Boolean type. + return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); } let array = as_list_array(&args[0])?; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index f11bc5206eb4..621cb4a8f4c0 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2322,7 +2322,7 @@ select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_a ## array containment operator # array containment operator with scalars #1 (at arrow) -query ??????? +query BBBBBBB select make_array(1,2,3) @> make_array(1,3), make_array(1,2,3) @> make_array(1,4), make_array([1,2], [3,4]) @> make_array([1,2]), @@ -2334,7 +2334,7 @@ select make_array(1,2,3) @> make_array(1,3), true false true false false false true # array containment operator with scalars #2 (arrow at) -query ??????? +query BBBBBBB select make_array(1,3) <@ make_array(1,2,3), make_array(1,4) <@ make_array(1,2,3), make_array([1,2]) <@ make_array([1,2], [3,4]), @@ -2465,7 +2465,7 @@ true query B select empty(make_array(NULL)); ---- -true +false # empty scalar function #4 query B diff --git a/datafusion/sqllogictest/test_files/tpch/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/q16.slt.part index fb9d98b76fe3..b93872929fe5 100644 --- a/datafusion/sqllogictest/test_files/tpch/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q16.slt.part @@ -52,9 +52,9 @@ limit 10; logical_plan Limit: skip=0, fetch=10 --Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10 -----Projection: group_alias_0 AS part.p_brand, group_alias_1 AS part.p_type, group_alias_2 AS part.p_size, COUNT(alias1) AS supplier_cnt -------Aggregate: groupBy=[[group_alias_0, group_alias_1, group_alias_2]], aggr=[[COUNT(alias1)]] ---------Aggregate: groupBy=[[part.p_brand AS group_alias_0, part.p_type AS group_alias_1, part.p_size AS group_alias_2, partsupp.ps_suppkey AS alias1]], aggr=[[]] +----Projection: part.p_brand, part.p_type, part.p_size, COUNT(alias1) AS supplier_cnt +------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1)]] +--------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]] ----------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey ------------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size --------------Inner Join: partsupp.ps_partkey = part.p_partkey @@ -69,15 +69,15 @@ physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10 ----SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] -------ProjectionExec: expr=[group_alias_0@0 as part.p_brand, group_alias_1@1 as part.p_type, group_alias_2@2 as part.p_size, COUNT(alias1)@3 as supplier_cnt] ---------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2], aggr=[COUNT(alias1)] +------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(alias1)@3 as supplier_cnt] +--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([group_alias_0@0, group_alias_1@1, group_alias_2@2], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2], aggr=[COUNT(alias1)] -----------------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2, alias1@3 as alias1], aggr=[] +------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] +----------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] ------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------RepartitionExec: partitioning=Hash([group_alias_0@0, group_alias_1@1, group_alias_2@2, alias1@3], 4), input_partitions=4 -----------------------AggregateExec: mode=Partial, gby=[p_brand@1 as group_alias_0, p_type@2 as group_alias_1, p_size@3 as group_alias_2, ps_suppkey@0 as alias1], aggr=[] +--------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 +----------------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[] ------------------------CoalesceBatchesExec: target_batch_size=8192 --------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] ----------------------------CoalesceBatchesExec: target_batch_size=8192 From b2509fdf53ef75985b722be54f929ae02b90913c Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 20 Oct 2023 11:30:30 +0300 Subject: [PATCH 02/15] Address todos --- datafusion/expr/src/built_in_function.rs | 53 +++++++++++++++++------- datafusion/expr/src/logical_plan/plan.rs | 12 +++--- 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fbe69deddf1c..a15cffe252f2 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -316,35 +316,56 @@ fn function_to_name() -> &'static HashMap { }) } -// TODO: Enrich this implementation +/// Returns the wider type among lhs and rhs +/// Returns Error if types are incompatible fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { Ok(match (lhs, rhs) { (DataType::Null, _) => rhs.clone(), (_, DataType::Null) => lhs.clone(), - // Int - ( - DataType::Int8, - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => rhs.clone(), - (DataType::Int16, DataType::Int16 | DataType::Int32 | DataType::Int64) => { + // Same UInt types + (DataType::UInt8, DataType::UInt8) => DataType::UInt8, + (DataType::UInt16, DataType::UInt16) => DataType::UInt16, + (DataType::UInt32, DataType::UInt32) => DataType::UInt32, + (DataType::UInt64, DataType::UInt64) => DataType::UInt64, + // Right UInt is larger than left UInt + (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64) => { rhs.clone() } - (DataType::Int32, DataType::Int32 | DataType::Int64) => rhs.clone(), + (DataType::UInt16, DataType::UInt32 | DataType::UInt64) => rhs.clone(), + (DataType::UInt32, DataType::UInt64) => rhs.clone(), + // Left UInt is larger than right UInt. + (DataType::UInt16 | DataType::UInt32 | DataType::UInt64, DataType::UInt8) => { + lhs.clone() + } + (DataType::UInt32 | DataType::UInt64, DataType::UInt16) => lhs.clone(), + (DataType::UInt64, DataType::UInt32) => lhs.clone(), + // Same Int types + (DataType::Int8, DataType::Int8) => DataType::Int8, + (DataType::Int16, DataType::Int16) => DataType::Int16, + (DataType::Int32, DataType::Int32) => DataType::Int32, + (DataType::Int64, DataType::Int64) => DataType::Int64, + // Right Int is larger than left Int + (DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64) => { + rhs.clone() + } + (DataType::Int16, DataType::Int32 | DataType::Int64) => rhs.clone(), + (DataType::Int32, DataType::Int64) => rhs.clone(), + // Left Int is larger than right Int. (DataType::Int16 | DataType::Int32 | DataType::Int64, DataType::Int8) => { lhs.clone() } (DataType::Int32 | DataType::Int64, DataType::Int16) => lhs.clone(), (DataType::Int64, DataType::Int32) => lhs.clone(), - (DataType::Int64, DataType::Int64) => DataType::Int64, - // Float - ( - DataType::Float16, - DataType::Float16 | DataType::Float32 | DataType::Float64, - ) => rhs.clone(), - (DataType::Float32, DataType::Float32 | DataType::Float64) => rhs.clone(), + // Same Float Types + (DataType::Float16, DataType::Float16) => DataType::Float16, + (DataType::Float32, DataType::Float32) => DataType::Float32, + (DataType::Float64, DataType::Float64) => DataType::Float64, + // Right Float is larger than left Float + (DataType::Float16, DataType::Float32 | DataType::Float64) => rhs.clone(), + (DataType::Float32, DataType::Float64) => rhs.clone(), + // Left Float is larger than right Float. (DataType::Float32 | DataType::Float64, DataType::Float16) => lhs.clone(), (DataType::Float64, DataType::Float32) => lhs.clone(), - (DataType::Float64, DataType::Float64) => DataType::Float64, // String (DataType::Utf8, DataType::Utf8 | DataType::LargeUtf8) => rhs.clone(), (DataType::LargeUtf8, DataType::Utf8) => lhs.clone(), diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 186315c698b5..30584f12a2a3 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -43,9 +43,10 @@ use datafusion_common::tree_node::{ VisitRecursion, }; use datafusion_common::{ - aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, - OwnedTableReference, Result, ScalarValue, ToDFSchema, UnnestOptions, + aggregate_functional_dependencies, internal_err, plan_datafusion_err, plan_err, + Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, + FunctionalDependencies, OwnedTableReference, Result, ScalarValue, ToDFSchema, + UnnestOptions, }; // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; @@ -1978,10 +1979,9 @@ impl TableScan { old_indices .iter() .position(|old_idx| old_idx == idx) - // TODO: Remove this unwrap - .unwrap() + .ok_or_else(|| plan_datafusion_err!("Refers to invalid index")) }) - .collect::>(); + .collect::>>()?; let func_dependencies = self.projected_schema.functional_dependencies(); let new_func_dependencies = func_dependencies .project_functional_dependencies(&new_proj, projection.len()); From 921bd5c9704d65fe2443234a1c300128ea4ad447 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 20 Oct 2023 11:44:25 +0300 Subject: [PATCH 03/15] Update comments --- datafusion/expr/src/built_in_function.rs | 11 +++++------ .../optimizer/src/single_distinct_to_groupby.rs | 3 +++ datafusion/physical-expr/src/array_expressions.rs | 12 ++++-------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index a15cffe252f2..ecada6b87808 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -317,6 +317,7 @@ fn function_to_name() -> &'static HashMap { } /// Returns the wider type among lhs and rhs +/// Wider type is the type that can represent other type without loss safely. /// Returns Error if types are incompatible fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { Ok(match (lhs, rhs) { @@ -378,12 +379,10 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); DataType::List(Arc::new(Field::new(field_name, field_type, nullable))) } - (_, _) => { - return Err(DataFusionError::Execution(format!( - "Cannot concat types lhs: {:?}, rhs:{:?}", - lhs, rhs - ))) - } + (_, _) => return Err(DataFusionError::Execution(format!( + "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", + lhs, rhs + ))), }) } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index d608cedee8eb..793d73131d5a 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -110,8 +110,11 @@ impl OptimizerRule for SingleDistinctToGroupBy { .map(|(i, group_expr)| { let (alias_expr, out_group_expr, original_name) = if let Expr::Column(_) = group_expr { + // For Column expressions we can use existing expression as is. (group_expr.clone(), group_expr.clone(), None) } else { + // For complex expression write is as alias, to be able to refer + // if from parent operators successfully. let alias_str = format!("group_alias_{i}"); let alias_expr = group_expr.clone().alias(&alias_str); ( diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 9fd5a7d07ec5..49efde72cbc5 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -421,15 +421,11 @@ fn array(values: &[ColumnarValue]) -> Result { let list_arr = ScalarValue::new_list(&[], &DataType::Null); Ok(Arc::new(list_arr)) } - // all nulls, set default data type as int32 + // all nulls, return a NullArray. Some(DataType::Null) => { - let nulls = arrays.len(); - let null_arr = NullArray::new(nulls); - let field = Arc::new(Field::new("item", DataType::Null, true)); - let offsets = OffsetBuffer::from_lengths([nulls]); - let values = Arc::new(null_arr) as ArrayRef; - let nulls = None; - Ok(Arc::new(ListArray::new(field, offsets, values, nulls))) + let null_arr = vec![ScalarValue::Null; arrays.len()]; + let list_arr = ScalarValue::new_list(null_arr.as_slice(), &DataType::Null); + Ok(Arc::new(list_arr)) } Some(data_type) => Ok(array_array(arrays.as_slice(), data_type)?), } From 7d8a91115573fc1115ee6df06d83b12e4c0ceae4 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 20 Oct 2023 13:40:54 +0300 Subject: [PATCH 04/15] Simplifications --- datafusion/expr/src/built_in_function.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index ecada6b87808..72f869403fc1 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -379,10 +379,12 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); DataType::List(Arc::new(Field::new(field_name, field_type, nullable))) } - (_, _) => return Err(DataFusionError::Execution(format!( + (_, _) => { + return Err(DataFusionError::Execution(format!( "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", lhs, rhs - ))), + ))) + } }) } From 80092c52cf66620ca07bd4c1fd8f30898ba57aa6 Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Fri, 20 Oct 2023 14:37:13 +0300 Subject: [PATCH 05/15] Minor simplifications --- datafusion/expr/src/built_in_function.rs | 24 +++++------------------- datafusion/expr/src/logical_plan/plan.rs | 6 ++---- 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 72f869403fc1..c70f8f88724a 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -316,18 +316,14 @@ fn function_to_name() -> &'static HashMap { }) } -/// Returns the wider type among lhs and rhs -/// Wider type is the type that can represent other type without loss safely. -/// Returns Error if types are incompatible +/// Returns the wider type among lhs and rhs. +/// Wider type is the type that can safely represent the other type without information loss. +/// Returns Error if types are incompatible. fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { Ok(match (lhs, rhs) { + (lhs, rhs) if lhs == rhs => lhs.clone(), (DataType::Null, _) => rhs.clone(), (_, DataType::Null) => lhs.clone(), - // Same UInt types - (DataType::UInt8, DataType::UInt8) => DataType::UInt8, - (DataType::UInt16, DataType::UInt16) => DataType::UInt16, - (DataType::UInt32, DataType::UInt32) => DataType::UInt32, - (DataType::UInt64, DataType::UInt64) => DataType::UInt64, // Right UInt is larger than left UInt (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64) => { rhs.clone() @@ -340,11 +336,6 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { } (DataType::UInt32 | DataType::UInt64, DataType::UInt16) => lhs.clone(), (DataType::UInt64, DataType::UInt32) => lhs.clone(), - // Same Int types - (DataType::Int8, DataType::Int8) => DataType::Int8, - (DataType::Int16, DataType::Int16) => DataType::Int16, - (DataType::Int32, DataType::Int32) => DataType::Int32, - (DataType::Int64, DataType::Int64) => DataType::Int64, // Right Int is larger than left Int (DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64) => { rhs.clone() @@ -357,10 +348,6 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { } (DataType::Int32 | DataType::Int64, DataType::Int16) => lhs.clone(), (DataType::Int64, DataType::Int32) => lhs.clone(), - // Same Float Types - (DataType::Float16, DataType::Float16) => DataType::Float16, - (DataType::Float32, DataType::Float32) => DataType::Float32, - (DataType::Float64, DataType::Float64) => DataType::Float64, // Right Float is larger than left Float (DataType::Float16, DataType::Float32 | DataType::Float64) => rhs.clone(), (DataType::Float32, DataType::Float64) => rhs.clone(), @@ -368,9 +355,8 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { (DataType::Float32 | DataType::Float64, DataType::Float16) => lhs.clone(), (DataType::Float64, DataType::Float32) => lhs.clone(), // String - (DataType::Utf8, DataType::Utf8 | DataType::LargeUtf8) => rhs.clone(), + (DataType::Utf8, DataType::LargeUtf8) => rhs.clone(), (DataType::LargeUtf8, DataType::Utf8) => lhs.clone(), - (DataType::LargeUtf8, DataType::LargeUtf8) => DataType::LargeUtf8, (DataType::List(lhs_field), DataType::List(rhs_field)) => { let field_type = get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 30584f12a2a3..dcfb3d903bdd 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1969,10 +1969,8 @@ impl TableScan { .collect(); // Find indices among previous schema - let old_indices = self - .projection - .clone() - .unwrap_or((0..self.projected_schema.fields().len()).collect()); + let projected_range = (0..self.projected_schema.fields().len()).collect(); + let old_indices = self.projection.as_ref().unwrap_or(&projected_range); let new_proj = projection .iter() .map(|idx| { From 1648b1b7dee9acb5c467dff5b253da2f19fe17a6 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 20 Oct 2023 14:58:55 +0300 Subject: [PATCH 06/15] Address reviews --- datafusion/expr/src/built_in_function.rs | 13 +++++++++---- datafusion/expr/src/type_coercion/binary.rs | 3 +++ .../optimizer/src/single_distinct_to_groupby.rs | 15 +++++++++++++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index c70f8f88724a..141e60ce1ce6 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -26,7 +26,8 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, + exec_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, + Result, }; use std::cmp::Ordering; use std::collections::HashMap; @@ -360,16 +361,20 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { (DataType::List(lhs_field), DataType::List(rhs_field)) => { let field_type = get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; + if lhs_field.name() != rhs_field.name() { + return Err(exec_datafusion_err!( + "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", + lhs, rhs)); + } assert_eq!(lhs_field.name(), rhs_field.name()); let field_name = lhs_field.name(); let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); DataType::List(Arc::new(Field::new(field_name, field_type, nullable))) } (_, _) => { - return Err(DataFusionError::Execution(format!( + return Err(exec_datafusion_err!( "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", - lhs, rhs - ))) + lhs, rhs)); } }) } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 541f5e9446e6..4279f3343355 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -120,6 +120,9 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result } Operator::AtArrow | Operator::ArrowAt => { + // ArrowAt and AtArrow check for whether one array ic contained in another. + // The result type is boolean. Signature::comparison defines this signature. + // Operation has nothing to do with comparison array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common array type for arrow operation {lhs} {op} {rhs}" diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 793d73131d5a..22f4e3beb5bb 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -115,6 +115,21 @@ impl OptimizerRule for SingleDistinctToGroupBy { } else { // For complex expression write is as alias, to be able to refer // if from parent operators successfully. + // Consider plan below. + // + // Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // First aggregate(from bottom) refers to `test.a` column. + // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. + // If we were to write plan above as below without alias + // + // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it. let alias_str = format!("group_alias_{i}"); let alias_expr = group_expr.clone().alias(&alias_str); ( From a9e44391347fc6edd54b36e8a74a0e4be9be4cfe Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 20 Oct 2023 17:29:54 +0300 Subject: [PATCH 07/15] Add TableScan constructor --- .../common/src/functional_dependencies.rs | 17 ++++ datafusion/expr/src/logical_plan/builder.rs | 73 +-------------- datafusion/expr/src/logical_plan/plan.rs | 91 +++++++++++-------- .../optimizer/src/push_down_projection.rs | 18 ++-- 4 files changed, 84 insertions(+), 115 deletions(-) diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 869709bc8dfc..fbddcddab4bc 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -558,4 +558,21 @@ mod tests { assert_eq!(iter.next(), Some(&Constraint::Unique(vec![20]))); assert_eq!(iter.next(), None); } + + #[test] + fn test_get_updated_id_keys() { + let fund_dependencies = + FunctionalDependencies::new(vec![FunctionalDependence::new( + vec![1], + vec![0, 1, 2], + true, + )]); + let res = fund_dependencies.project_functional_dependencies(&[1, 2], 2); + let expected = FunctionalDependencies::new(vec![FunctionalDependence::new( + vec![0], + vec![0, 1], + true, + )]); + assert_eq!(res, expected); + } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index cd50dbe79cfd..a6bae960aae9 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -46,8 +46,7 @@ use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::UnnestOptions; use datafusion_common::{ display::ToStringifiedPlan, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, - FileType, FunctionalDependencies, OwnedTableReference, Result, ScalarValue, - TableReference, ToDFSchema, + FileType, OwnedTableReference, Result, ScalarValue, TableReference, ToDFSchema, }; use datafusion_common::{plan_datafusion_err, plan_err}; use std::any::Any; @@ -283,52 +282,9 @@ impl LogicalPlanBuilder { projection: Option>, filters: Vec, ) -> Result { - let table_name = table_name.into(); - - if table_name.table().is_empty() { - return plan_err!("table_name cannot be empty"); - } - - let schema = table_source.schema(); - let func_dependencies = FunctionalDependencies::new_from_constraints( - table_source.constraints(), - schema.fields.len(), - ); - - let projected_schema = projection - .as_ref() - .map(|p| { - let projected_func_dependencies = - func_dependencies.project_functional_dependencies(p, p.len()); - DFSchema::new_with_metadata( - p.iter() - .map(|i| { - DFField::from_qualified( - table_name.clone(), - schema.field(*i).clone(), - ) - }) - .collect(), - schema.metadata().clone(), - ) - .map(|df_schema| { - df_schema.with_functional_dependencies(projected_func_dependencies) - }) - }) - .unwrap_or_else(|| { - DFSchema::try_from_qualified_schema(table_name.clone(), &schema).map( - |df_schema| df_schema.with_functional_dependencies(func_dependencies), - ) - })?; - - let table_scan = LogicalPlan::TableScan(TableScan { - table_name, - source: table_source, - projected_schema: Arc::new(projected_schema), - projection, - filters, - fetch: None, - }); + let table_scan = + TableScan::try_new(table_name, table_source, projection, filters, None)?; + let table_scan = LogicalPlan::TableScan(table_scan); Ok(Self::from(table_scan)) } @@ -1548,9 +1504,7 @@ mod tests { use super::*; use arrow::datatypes::{DataType, Field}; - use datafusion_common::{ - FunctionalDependence, OwnedTableReference, SchemaError, TableReference, - }; + use datafusion_common::{OwnedTableReference, SchemaError, TableReference}; #[test] fn plan_builder_simple() -> Result<()> { @@ -2051,21 +2005,4 @@ mod tests { Ok(()) } - - #[test] - fn test_get_updated_id_keys() { - let fund_dependencies = - FunctionalDependencies::new(vec![FunctionalDependence::new( - vec![1], - vec![0, 1, 2], - true, - )]); - let res = fund_dependencies.project_functional_dependencies(&[1, 2], 2); - let expected = FunctionalDependencies::new(vec![FunctionalDependence::new( - vec![0], - vec![0, 1], - true, - )]); - assert_eq!(res, expected); - } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index dcfb3d903bdd..ec2cf694516f 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -43,10 +43,9 @@ use datafusion_common::tree_node::{ VisitRecursion, }; use datafusion_common::{ - aggregate_functional_dependencies, internal_err, plan_datafusion_err, plan_err, - Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, - FunctionalDependencies, OwnedTableReference, Result, ScalarValue, ToDFSchema, - UnnestOptions, + aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, + DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, + OwnedTableReference, Result, ScalarValue, UnnestOptions, }; // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; @@ -1955,41 +1954,59 @@ impl Hash for TableScan { } impl TableScan { - /// Create projection schema using entries at the projection indices - pub fn project_schema(&self, projection: &[usize]) -> Result> { - let schema = self.source.schema(); - let projected_fields: Vec = projection - .iter() - .map(|i| { - DFField::from_qualified( - self.table_name.clone(), - schema.fields()[*i].clone(), - ) - }) - .collect(); + /// Initialize TableScan with appropriate schema from the given + /// arguments. + pub fn try_new( + table_name: impl Into, + table_source: Arc, + projection: Option>, + filters: Vec, + fetch: Option, + ) -> Result { + let table_name = table_name.into(); - // Find indices among previous schema - let projected_range = (0..self.projected_schema.fields().len()).collect(); - let old_indices = self.projection.as_ref().unwrap_or(&projected_range); - let new_proj = projection - .iter() - .map(|idx| { - old_indices - .iter() - .position(|old_idx| old_idx == idx) - .ok_or_else(|| plan_datafusion_err!("Refers to invalid index")) - }) - .collect::>>()?; - let func_dependencies = self.projected_schema.functional_dependencies(); - let new_func_dependencies = func_dependencies - .project_functional_dependencies(&new_proj, projection.len()); - - let projected_schema = Arc::new( - projected_fields - .to_dfschema()? - .with_functional_dependencies(new_func_dependencies), + if table_name.table().is_empty() { + return plan_err!("table_name cannot be empty"); + } + let schema = table_source.schema(); + let func_dependencies = FunctionalDependencies::new_from_constraints( + table_source.constraints(), + schema.fields.len(), ); - Ok(projected_schema) + let projected_schema = projection + .as_ref() + .map(|p| { + let projected_func_dependencies = + func_dependencies.project_functional_dependencies(p, p.len()); + DFSchema::new_with_metadata( + p.iter() + .map(|i| { + DFField::from_qualified( + table_name.clone(), + schema.field(*i).clone(), + ) + }) + .collect(), + schema.metadata().clone(), + ) + .map(|df_schema| { + df_schema.with_functional_dependencies(projected_func_dependencies) + }) + }) + .unwrap_or_else(|| { + DFSchema::try_from_qualified_schema(table_name.clone(), &schema).map( + |df_schema| df_schema.with_functional_dependencies(func_dependencies), + ) + })?; + let projected_schema = Arc::new(projected_schema); + Ok(Self { + table_name, + source: table_source, + projection, + projected_schema, + filters, + fetch, + }) } } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 06c9ad9413c6..038be22bc77c 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -598,16 +598,14 @@ fn push_down_scan( projection.into_iter().collect::>() }; - let projected_schema = scan.project_schema(&projection)?; - - Ok(LogicalPlan::TableScan(TableScan { - table_name: scan.table_name.clone(), - source: scan.source.clone(), - projection: Some(projection), - projected_schema, - filters: scan.filters.clone(), - fetch: scan.fetch, - })) + let table_scan = TableScan::try_new( + scan.table_name.clone(), + scan.source.clone(), + Some(projection), + scan.filters.clone(), + scan.fetch, + )?; + Ok(LogicalPlan::TableScan(table_scan)) } fn restrict_outputs( From c939460c05fd0e9470444d80ab6d1e4d8c85b71e Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 20 Oct 2023 17:59:39 +0300 Subject: [PATCH 08/15] Minor changes --- datafusion/expr/src/logical_plan/plan.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ec2cf694516f..59bbfe86fe8c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -549,13 +549,11 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, - schema, .. - }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(inputs[0].clone()), group_expr.to_vec(), aggr_expr.to_vec(), - schema.clone(), )?)), _ => self.with_new_exprs(self.expressions(), inputs), } @@ -710,17 +708,14 @@ impl LogicalPlan { schema: schema.clone(), })) } - LogicalPlan::Aggregate(Aggregate { - group_expr, schema, .. - }) => { + LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { // group exprs are the first expressions let agg_expr = expr.split_off(group_expr.len()); - Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + Ok(LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(inputs[0].clone()), expr, agg_expr, - schema.clone(), )?)) } LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { From cecb82fb7b93d9050a3066c6b30ad1921bafedf0 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 20 Oct 2023 18:02:52 +0300 Subject: [PATCH 09/15] make try_new_with_schema method of Aggregate private --- datafusion/expr/src/logical_plan/plan.rs | 2 +- datafusion/optimizer/src/replace_distinct_aggregate.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 59bbfe86fe8c..6e418411b380 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2183,7 +2183,7 @@ impl Aggregate { /// /// This method should only be called when you are absolutely sure that the schema being /// provided is correct for the aggregate. If in doubt, call [try_new](Self::try_new) instead. - pub fn try_new_with_schema( + fn try_new_with_schema( input: Arc, group_expr: Vec, aggr_expr: Vec, diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index f58d4b159745..9bbbc52aa5db 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -54,11 +54,10 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { match plan { LogicalPlan::Distinct(Distinct { input }) => { let group_expr = expand_wildcard(input.schema(), input, None)?; - let aggregate = LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), group_expr, vec![], - input.schema().clone(), // input schema and aggregate schema are the same in this case )?); Ok(Some(aggregate)) } From 70d0a267ab4b79026fb8e7f70be4c1d1c90f5639 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 20 Oct 2023 18:14:12 +0300 Subject: [PATCH 10/15] Use projection try_new instead of try_new_schema --- datafusion/expr/src/logical_plan/plan.rs | 13 ++++--------- datafusion/expr/src/tree_node/expr.rs | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6e418411b380..cf95879ff4e5 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -531,10 +531,9 @@ impl LogicalPlan { // so we don't need to recompute Schema. match &self { LogicalPlan::Projection(projection) => { - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( + Ok(LogicalPlan::Projection(Projection::try_new( projection.expr.to_vec(), Arc::new(inputs[0].clone()), - projection.schema.clone(), )?)) } LogicalPlan::Window(Window { @@ -588,13 +587,9 @@ impl LogicalPlan { inputs: &[LogicalPlan], ) -> Result { match self { - LogicalPlan::Projection(Projection { schema, .. }) => { - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - expr, - Arc::new(inputs[0].clone()), - schema.clone(), - )?)) - } + LogicalPlan::Projection(Projection { .. }) => Ok(LogicalPlan::Projection( + Projection::try_new(expr, Arc::new(inputs[0].clone()))?, + )), LogicalPlan::Dml(DmlStatement { table_name, table_schema, diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index f74cc164a7a5..06581cbf434a 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -22,7 +22,7 @@ use crate::expr::{ GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction, }; -use crate::Expr; +use crate::{Expr, GetFieldAccess}; use datafusion_common::tree_node::VisitRecursion; use datafusion_common::{tree_node::TreeNode, Result}; @@ -47,8 +47,20 @@ impl TreeNode for Expr { | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], - Expr::GetIndexedField(GetIndexedField { expr, .. }) => { - vec![expr.as_ref().clone()] + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let mut exprs = match field{ + GetFieldAccess::ListIndex {key} => { + vec![key.as_ref().clone()] + }, + GetFieldAccess::ListRange {start, stop} => { + vec![start.as_ref().clone(), stop.as_ref().clone()] + } + GetFieldAccess::NamedStructField {name: _name} => { + vec![] + } + }; + exprs.push(expr.as_ref().clone()); + exprs } Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), From 7273cc8ed632e89fe0aae4bc506e26a79a3fb3cf Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 23 Oct 2023 11:47:30 +0300 Subject: [PATCH 11/15] Simplifications, add comment --- datafusion/expr/src/logical_plan/plan.rs | 10 +++++++++- .../optimizer/src/common_subexpr_eliminate.rs | 2 ++ datafusion/optimizer/src/merge_projection.rs | 1 + .../optimizer/src/single_distinct_to_groupby.rs | 15 ++++++++------- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index cf95879ff4e5..9dbf72d15a00 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -531,6 +531,9 @@ impl LogicalPlan { // so we don't need to recompute Schema. match &self { LogicalPlan::Projection(projection) => { + // Schema of the projection may change + // when its input changes. Hence we should use + // `try_new` method instead of `try_new_with_schema`. Ok(LogicalPlan::Projection(Projection::try_new( projection.expr.to_vec(), Arc::new(inputs[0].clone()), @@ -550,6 +553,9 @@ impl LogicalPlan { aggr_expr, .. }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new( + // Schema of the aggregate may change + // when its input changes. Hence we should use + // `try_new` method instead of `try_new_with_schema`. Arc::new(inputs[0].clone()), group_expr.to_vec(), aggr_expr.to_vec(), @@ -587,6 +593,8 @@ impl LogicalPlan { inputs: &[LogicalPlan], ) -> Result { match self { + // Since expr may be different than the previous expr, schema of the projection + // may change. We need to use try_new method instead of try_new_with_schema method. LogicalPlan::Projection(Projection { .. }) => Ok(LogicalPlan::Projection( Projection::try_new(expr, Arc::new(inputs[0].clone()))?, )), @@ -2178,7 +2186,7 @@ impl Aggregate { /// /// This method should only be called when you are absolutely sure that the schema being /// provided is correct for the aggregate. If in doubt, call [try_new](Self::try_new) instead. - fn try_new_with_schema( + pub fn try_new_with_schema( input: Arc, group_expr: Vec, aggr_expr: Vec, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 17e75ccf41dd..78ba1cfe6d7e 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -119,6 +119,7 @@ impl CommonSubexprEliminate { let (mut new_expr, new_input) = self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?; + // Since projection expr changes, schema changes also. Use try_new method. Ok(LogicalPlan::Projection(Projection::try_new( pop_expr(&mut new_expr)?, Arc::new(new_input), @@ -240,6 +241,7 @@ impl CommonSubexprEliminate { let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { + // Since group_epxr changes, schema changes also. Use try_new method. Ok(LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(new_input), new_group_expr, diff --git a/datafusion/optimizer/src/merge_projection.rs b/datafusion/optimizer/src/merge_projection.rs index 5f208924d03b..e8388b31655d 100644 --- a/datafusion/optimizer/src/merge_projection.rs +++ b/datafusion/optimizer/src/merge_projection.rs @@ -84,6 +84,7 @@ pub(super) fn merge_projection( Err(e) => Err(e), }) .collect::>>()?; + // Use try_new, since schema changes with changing expressions. let new_plan = LogicalPlan::Projection(Projection::try_new( new_exprs, child_projection.input.clone(), diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 22f4e3beb5bb..cac373f50ea7 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -103,8 +103,10 @@ impl OptimizerRule for SingleDistinctToGroupBy { }) => { if is_single_distinct_agg(plan)? && !contains_grouping_set(group_expr) { // alias all original group_by exprs - let mut group_expr_alias = Vec::with_capacity(group_expr.len()); - let mut inner_group_exprs = group_expr + let (mut inner_group_exprs, out_group_expr_with_alias): ( + Vec, + Vec<(Expr, Option)>, + ) = group_expr .iter() .enumerate() .map(|(i, group_expr)| { @@ -138,13 +140,12 @@ impl OptimizerRule for SingleDistinctToGroupBy { Some(schema.fields()[i].qualified_name()), ) }; - group_expr_alias.push((out_group_expr, original_name)); - alias_expr + (alias_expr, (out_group_expr, original_name)) }) - .collect::>(); + .unzip(); // and they can be referenced by the alias in the outer aggr plan - let outer_group_exprs = group_expr_alias + let outer_group_exprs = out_group_expr_with_alias .iter() .map(|(out_group_expr, _)| out_group_expr.clone()) .collect::>(); @@ -209,7 +210,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // - group_by aggr // - aggr expr let mut alias_expr: Vec = Vec::new(); - for (group_expr, original_field) in group_expr_alias { + for (group_expr, original_field) in out_group_expr_with_alias { let expr = if let Some(name) = original_field { group_expr.alias(name) } else { From cb187bf7a88326ab8e1a41792fb4c3f7e97ff63a Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Tue, 24 Oct 2023 15:54:29 +0300 Subject: [PATCH 12/15] Review changes --- datafusion/expr/src/built_in_function.rs | 41 +++--- datafusion/expr/src/logical_plan/builder.rs | 118 ++++++++---------- datafusion/expr/src/logical_plan/plan.rs | 59 ++++----- datafusion/expr/src/tree_node/expr.rs | 18 +-- datafusion/expr/src/type_coercion/binary.rs | 7 +- .../optimizer/src/common_subexpr_eliminate.rs | 24 ++-- datafusion/optimizer/src/merge_projection.rs | 7 +- .../optimizer/src/push_down_projection.rs | 36 +++--- .../src/replace_distinct_aggregate.rs | 7 +- .../src/single_distinct_to_groupby.rs | 99 ++++++++------- .../physical-expr/src/array_expressions.rs | 36 +++--- 11 files changed, 208 insertions(+), 244 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 141e60ce1ce6..c576afb37beb 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -17,6 +17,12 @@ //! Built-in functions module contains all the built-in functions definitions. +use std::cmp::Ordering; +use std::collections::HashMap; +use std::fmt; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; + use crate::nullif::SUPPORTED_NULLIF_TYPES; use crate::signature::TIMEZONE_WILDCARD; use crate::type_coercion::functions::data_types; @@ -24,16 +30,13 @@ use crate::{ conditional_expressions, struct_expressions, utils, FuncMonotonicity, Signature, TypeSignature, Volatility, }; + use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, }; -use std::cmp::Ordering; -use std::collections::HashMap; -use std::fmt; -use std::str::FromStr; -use std::sync::{Arc, OnceLock}; + use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -363,8 +366,9 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; if lhs_field.name() != rhs_field.name() { return Err(exec_datafusion_err!( - "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", - lhs, rhs)); + "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", + lhs, rhs + )); } assert_eq!(lhs_field.name(), rhs_field.name()); let field_name = lhs_field.name(); @@ -373,8 +377,9 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { } (_, _) => { return Err(exec_datafusion_err!( - "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", - lhs, rhs)); + "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", + lhs, rhs + )); } }) } @@ -532,18 +537,14 @@ impl BuiltinScalarFunction { /// * `List(Int64)` has dimension 2 /// * `List(List(Int64))` has dimension 3 /// * etc. - fn return_dimension(self, input_expr_type: DataType) -> u64 { - let mut res: u64 = 1; + fn return_dimension(self, input_expr_type: &DataType) -> u64 { + let mut result: u64 = 1; let mut current_data_type = input_expr_type; - loop { - match current_data_type { - DataType::List(field) => { - current_data_type = field.data_type().clone(); - res += 1; - } - _ => return res, - } + while let DataType::List(field) = current_data_type { + current_data_type = field.data_type(); + result += 1; } + result } /// Returns the output [`DataType`] of this function @@ -602,7 +603,7 @@ impl BuiltinScalarFunction { match input_expr_type { List(field) => { if !field.data_type().equals_datatype(&Null) { - let dims = self.return_dimension(input_expr_type.clone()); + let dims = self.return_dimension(input_expr_type); expr_type = match max_dims.cmp(&dims) { Ordering::Greater => expr_type, Ordering::Equal => { diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index a6bae960aae9..9ce1d203d1c9 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -17,6 +17,13 @@ //! This module provides a builder for creating LogicalPlans +use std::any::Any; +use std::cmp::Ordering; +use std::collections::{HashMap, HashSet}; +use std::convert::TryFrom; +use std::iter::zip; +use std::sync::Arc; + use crate::dml::{CopyOptions, CopyTo}; use crate::expr::Alias; use crate::expr_rewriter::{ @@ -24,37 +31,29 @@ use crate::expr_rewriter::{ normalize_col_with_schemas_and_ambiguity_check, normalize_cols, rewrite_sort_cols_by_aggs, }; +use crate::logical_plan::{ + Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, + Window, +}; use crate::type_coercion::binary::comparison_coercion; -use crate::utils::{columnize_expr, compare_sort_expr}; -use crate::{ - and, binary_expr, DmlStatement, Operator, TableProviderFilterPushDown, WriteOp, +use crate::utils::{ + can_hash, columnize_expr, compare_sort_expr, expand_qualified_wildcard, + expand_wildcard, find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, }; use crate::{ - logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, - Window, - }, - utils::{ - can_hash, expand_qualified_wildcard, expand_wildcard, - find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, - }, - Expr, ExprSchemable, TableSource, + and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, + TableProviderFilterPushDown, TableSource, WriteOp, }; + use arrow::datatypes::{DataType, Schema, SchemaRef}; -use datafusion_common::UnnestOptions; +use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ - display::ToStringifiedPlan, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, - FileType, OwnedTableReference, Result, ScalarValue, TableReference, ToDFSchema, + plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, + DataFusionError, FileType, OwnedTableReference, Result, ScalarValue, TableReference, + ToDFSchema, UnnestOptions, }; -use datafusion_common::{plan_datafusion_err, plan_err}; -use std::any::Any; -use std::cmp::Ordering; -use std::collections::{HashMap, HashSet}; -use std::convert::TryFrom; -use std::iter::zip; -use std::sync::Arc; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -282,10 +281,9 @@ impl LogicalPlanBuilder { projection: Option>, filters: Vec, ) -> Result { - let table_scan = - TableScan::try_new(table_name, table_source, projection, filters, None)?; - let table_scan = LogicalPlan::TableScan(table_scan); - Ok(Self::from(table_scan)) + TableScan::try_new(table_name, table_source, projection, filters, None) + .map(LogicalPlan::TableScan) + .map(Self::from) } /// Wrap a plan in a window @@ -330,7 +328,7 @@ impl LogicalPlanBuilder { self, expr: impl IntoIterator>, ) -> Result { - Ok(Self::from(project(self.plan, expr)?)) + project(self.plan, expr).map(Self::from) } /// Select the given column indices @@ -346,10 +344,9 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; - Ok(Self::from(LogicalPlan::Filter(Filter::try_new( - expr, - Arc::new(self.plan), - )?))) + Filter::try_new(expr, Arc::new(self.plan)) + .map(LogicalPlan::Filter) + .map(Self::from) } /// Make a builder for a prepare logical plan from the builder's plan @@ -377,7 +374,7 @@ impl LogicalPlanBuilder { /// Apply an alias pub fn alias(self, alias: impl Into) -> Result { - Ok(Self::from(subquery_alias(self.plan, alias)?)) + subquery_alias(self.plan, alias).map(Self::from) } /// Add missing sort columns to all downstream projection @@ -432,7 +429,7 @@ impl LogicalPlanBuilder { Self::ambiguous_distinct_check(&missing_exprs, missing_cols, &expr)?; } expr.extend(missing_exprs); - Ok(project((*input).clone(), expr)?) + project((*input).clone(), expr) } _ => { let is_distinct = @@ -539,15 +536,14 @@ impl LogicalPlanBuilder { fetch: None, }); - Ok(Self::from(LogicalPlan::Projection(Projection::try_new( - new_expr, - Arc::new(sort_plan), - )?))) + Projection::try_new(new_expr, Arc::new(sort_plan)) + .map(LogicalPlan::Projection) + .map(Self::from) } /// Apply a union, preserving duplicate rows pub fn union(self, plan: LogicalPlan) -> Result { - Ok(Self::from(union(self.plan, plan)?)) + union(self.plan, plan).map(Self::from) } /// Apply a union, removing duplicate rows @@ -897,11 +893,9 @@ impl LogicalPlanBuilder { ) -> Result { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(self.plan), - group_expr, - aggr_expr, - )?))) + Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) + .map(LogicalPlan::Aggregate) + .map(Self::from) } /// Create an expression to represent the explanation of the plan @@ -1159,8 +1153,8 @@ pub fn build_join_schema( ); let mut metadata = left.metadata().clone(); metadata.extend(right.metadata().clone()); - Ok(DFSchema::new_with_metadata(fields, metadata)? - .with_functional_dependencies(func_dependencies)) + DFSchema::new_with_metadata(fields, metadata) + .map(|schema| schema.with_functional_dependencies(func_dependencies)) } /// Errors if one or more expressions have equal names. @@ -1207,9 +1201,8 @@ pub fn project_with_column_index( }) .collect::>(); - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - alias_expr, input, schema, - )?)) + Projection::try_new_with_schema(alias_expr, input, schema) + .map(LogicalPlan::Projection) } /// Union two logical plans. @@ -1305,10 +1298,7 @@ pub fn project( } validate_unique_names("Projections", projected_expr.iter())?; - Ok(LogicalPlan::Projection(Projection::try_new( - projected_expr, - Arc::new(plan.clone()), - )?)) + Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection) } /// Create a SubqueryAlias to wrap a LogicalPlan. @@ -1316,9 +1306,7 @@ pub fn subquery_alias( plan: LogicalPlan, alias: impl Into, ) -> Result { - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - plan, alias, - )?)) + SubqueryAlias::try_new(plan, alias).map(LogicalPlan::SubqueryAlias) } /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. @@ -1481,11 +1469,11 @@ pub fn unnest_with_options( }) .collect::>(); - let schema = Arc::new( - DFSchema::new_with_metadata(fields, input_schema.metadata().clone())? - // We can use the existing functional dependencies: - .with_functional_dependencies(input_schema.functional_dependencies().clone()), - ); + let metadata = input_schema.metadata().clone(); + let df_schema = DFSchema::new_with_metadata(fields, metadata)?; + // We can use the existing functional dependencies: + let deps = input_schema.functional_dependencies().clone(); + let schema = Arc::new(df_schema.with_functional_dependencies(deps)); Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), @@ -1497,11 +1485,9 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { - use crate::logical_plan::StringifiedPlan; - use crate::{col, in_subquery, lit, scalar_subquery, sum}; - use crate::{expr, expr_fn::exists}; - use super::*; + use crate::logical_plan::StringifiedPlan; + use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery, sum}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{OwnedTableReference, SchemaError, TableReference}; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9dbf72d15a00..d62ac8926328 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -17,6 +17,13 @@ //! Logical plan types +use std::collections::{HashMap, HashSet}; +use std::fmt::{self, Debug, Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use super::dml::CopyTo; +use super::DdlStatement; use crate::dml::CopyOptions; use crate::expr::{Alias, Exists, InSubquery, Placeholder}; use crate::expr_rewriter::create_col_from_scalar_expr; @@ -28,15 +35,11 @@ use crate::utils::{ grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, }; use crate::{ - build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown, TableSource, -}; -use crate::{ - expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, LogicalPlanBuilder, Operator, + build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, + ExprSchemable, LogicalPlanBuilder, Operator, TableProviderFilterPushDown, + TableSource, }; -use super::dml::CopyTo; -use super::DdlStatement; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor, @@ -51,11 +54,6 @@ use datafusion_common::{ pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; -use std::collections::{HashMap, HashSet}; -use std::fmt::{self, Debug, Display, Formatter}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by /// the SQL query planner and the DataFrame API. @@ -534,10 +532,8 @@ impl LogicalPlan { // Schema of the projection may change // when its input changes. Hence we should use // `try_new` method instead of `try_new_with_schema`. - Ok(LogicalPlan::Projection(Projection::try_new( - projection.expr.to_vec(), - Arc::new(inputs[0].clone()), - )?)) + Projection::try_new(projection.expr.to_vec(), Arc::new(inputs[0].clone())) + .map(LogicalPlan::Projection) } LogicalPlan::Window(Window { window_expr, @@ -552,14 +548,15 @@ impl LogicalPlan { group_expr, aggr_expr, .. - }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new( + }) => Aggregate::try_new( // Schema of the aggregate may change // when its input changes. Hence we should use // `try_new` method instead of `try_new_with_schema`. Arc::new(inputs[0].clone()), group_expr.to_vec(), aggr_expr.to_vec(), - )?)), + ) + .map(LogicalPlan::Aggregate), _ => self.with_new_exprs(self.expressions(), inputs), } } @@ -595,9 +592,10 @@ impl LogicalPlan { match self { // Since expr may be different than the previous expr, schema of the projection // may change. We need to use try_new method instead of try_new_with_schema method. - LogicalPlan::Projection(Projection { .. }) => Ok(LogicalPlan::Projection( - Projection::try_new(expr, Arc::new(inputs[0].clone()))?, - )), + LogicalPlan::Projection(Projection { .. }) => { + Projection::try_new(expr, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Projection) + } LogicalPlan::Dml(DmlStatement { table_name, table_schema, @@ -673,10 +671,8 @@ impl LogicalPlan { let mut remove_aliases = RemoveAliases {}; let predicate = predicate.rewrite(&mut remove_aliases)?; - Ok(LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(inputs[0].clone()), - )?)) + Filter::try_new(predicate, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Filter) } LogicalPlan::Repartition(Repartition { partitioning_scheme, @@ -715,11 +711,8 @@ impl LogicalPlan { // group exprs are the first expressions let agg_expr = expr.split_off(group_expr.len()); - Ok(LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(inputs[0].clone()), - expr, - agg_expr, - )?)) + Aggregate::try_new(Arc::new(inputs[0].clone()), expr, agg_expr) + .map(LogicalPlan::Aggregate) } LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { expr, @@ -788,10 +781,8 @@ impl LogicalPlan { })) } LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - inputs[0].clone(), - alias.clone(), - )?)) + SubqueryAlias::try_new(inputs[0].clone(), alias.clone()) + .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. }) => { Ok(LogicalPlan::Limit(Limit { diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 06581cbf434a..764dcffbced9 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -23,8 +23,9 @@ use crate::expr::{ ScalarUDF, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::VisitRecursion; -use datafusion_common::{tree_node::TreeNode, Result}; + +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::Result; impl TreeNode for Expr { fn apply_children(&self, op: &mut F) -> Result @@ -48,19 +49,18 @@ impl TreeNode for Expr { | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let mut exprs = match field{ + let expr = expr.as_ref().clone(); + match field { GetFieldAccess::ListIndex {key} => { - vec![key.as_ref().clone()] + vec![key.as_ref().clone(), expr] }, GetFieldAccess::ListRange {start, stop} => { - vec![start.as_ref().clone(), stop.as_ref().clone()] + vec![start.as_ref().clone(), stop.as_ref().clone(), expr] } GetFieldAccess::NamedStructField {name: _name} => { - vec![] + vec![expr] } - }; - exprs.push(expr.as_ref().clone()); - exprs + } } Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 4279f3343355..f028b721ce08 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -17,6 +17,8 @@ //! Coercion rules for matching argument types for binary operators +use crate::Operator; + use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ @@ -24,10 +26,7 @@ use arrow::datatypes::{ DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{plan_datafusion_err, Result}; -use datafusion_common::{plan_err, DataFusionError}; - -use crate::Operator; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; /// The type signature of an instantiation of binary operator expression such as /// `lhs + rhs` diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 78ba1cfe6d7e..68a6a5607a1d 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -20,6 +20,8 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; +use crate::{utils, OptimizerConfig, OptimizerRule}; + use arrow::datatypes::DataType; use datafusion_common::tree_node::{ RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, @@ -28,13 +30,10 @@ use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::expr::Alias; -use datafusion_expr::{ - col, - logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window}, - Expr, ExprSchemable, +use datafusion_expr::logical_plan::{ + Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; - -use crate::{utils, OptimizerConfig, OptimizerRule}; +use datafusion_expr::{col, Expr, ExprSchemable}; /// A map from expression's identifier to tuple including /// - the expression itself (cloned) @@ -120,10 +119,8 @@ impl CommonSubexprEliminate { self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?; // Since projection expr changes, schema changes also. Use try_new method. - Ok(LogicalPlan::Projection(Projection::try_new( - pop_expr(&mut new_expr)?, - Arc::new(new_input), - )?)) + Projection::try_new(pop_expr(&mut new_expr)?, Arc::new(new_input)) + .map(LogicalPlan::Projection) } fn try_optimize_filter( @@ -242,11 +239,8 @@ impl CommonSubexprEliminate { if affected_id.is_empty() { // Since group_epxr changes, schema changes also. Use try_new method. - Ok(LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(new_input), - new_group_expr, - new_aggr_expr, - )?)) + Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) + .map(LogicalPlan::Aggregate) } else { let mut agg_exprs = vec![]; diff --git a/datafusion/optimizer/src/merge_projection.rs b/datafusion/optimizer/src/merge_projection.rs index e8388b31655d..ec040cba6fe4 100644 --- a/datafusion/optimizer/src/merge_projection.rs +++ b/datafusion/optimizer/src/merge_projection.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::optimizer::ApplyOrder; -use datafusion_common::Result; -use datafusion_expr::{Expr, LogicalPlan, Projection}; use std::collections::HashMap; +use crate::optimizer::ApplyOrder; use crate::push_down_filter::replace_cols_by_name; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::{Expr, LogicalPlan, Projection}; + /// Optimization rule that merge [LogicalPlan::Projection]. #[derive(Default)] pub struct MergeProjection; diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 038be22bc77c..897ae69f947a 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -18,11 +18,15 @@ //! Projection Push Down optimizer rule ensures that only referenced columns are //! loaded into memory +use std::collections::{BTreeSet, HashMap, HashSet}; +use std::sync::Arc; + use crate::eliminate_project::can_eliminate; use crate::merge_projection::merge_projection; use crate::optimizer::ApplyOrder; use crate::push_down_filter::replace_cols_by_name; use crate::{OptimizerConfig, OptimizerRule}; + use arrow::datatypes::DataType; use arrow::error::Result as ArrowResult; use datafusion_common::ScalarValue::UInt8; @@ -30,17 +34,11 @@ use datafusion_common::{ plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::expr::{AggregateFunction, Alias}; -use datafusion_expr::utils::exprlist_to_fields; use datafusion_expr::{ logical_plan::{Aggregate, LogicalPlan, Projection, TableScan, Union}, - utils::{expr_to_columns, exprlist_to_columns}, + utils::{expr_to_columns, exprlist_to_columns, exprlist_to_fields}, Expr, LogicalPlanBuilder, SubqueryAlias, }; -use std::collections::HashMap; -use std::{ - collections::{BTreeSet, HashSet}, - sync::Arc, -}; // if projection is empty return projection-new_plan, else return new_plan. #[macro_export] @@ -598,14 +596,14 @@ fn push_down_scan( projection.into_iter().collect::>() }; - let table_scan = TableScan::try_new( + TableScan::try_new( scan.table_name.clone(), scan.source.clone(), Some(projection), scan.filters.clone(), scan.fetch, - )?; - Ok(LogicalPlan::TableScan(table_scan)) + ) + .map(LogicalPlan::TableScan) } fn restrict_outputs( @@ -625,25 +623,25 @@ fn restrict_outputs( #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::vec; + use super::*; use crate::eliminate_project::EliminateProjection; use crate::optimizer::Optimizer; use crate::test::*; use crate::OptimizerContext; + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion_common::DFSchema; use datafusion_expr::builder::table_scan_with_filters; - use datafusion_expr::expr; - use datafusion_expr::expr::Cast; - use datafusion_expr::WindowFrame; - use datafusion_expr::WindowFunction; + use datafusion_expr::expr::{self, Cast}; + use datafusion_expr::logical_plan::{ + builder::LogicalPlanBuilder, table_scan, JoinType, + }; use datafusion_expr::{ - col, count, lit, - logical_plan::{builder::LogicalPlanBuilder, table_scan, JoinType}, - max, min, AggregateFunction, Expr, + col, count, lit, max, min, AggregateFunction, Expr, WindowFrame, WindowFunction, }; - use std::collections::HashMap; - use std::vec; #[test] fn aggregate_no_group_by() -> Result<()> { diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 9bbbc52aa5db..540617b77084 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::optimizer::ApplyOrder; +use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::Result; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::Distinct; -use datafusion_expr::{Aggregate, LogicalPlan}; -use ApplyOrder::BottomUp; +use datafusion_expr::{Aggregate, Distinct, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index cac373f50ea7..8e0f93cb5781 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -17,8 +17,11 @@ //! single distinct to group by optimizer rule +use std::sync::Arc; + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ col, @@ -27,8 +30,8 @@ use datafusion_expr::{ utils::columnize_expr, Expr, ExprSchemable, }; + use hashbrown::HashSet; -use std::sync::Arc; /// single distinct to group by optimizer rule /// ```text @@ -102,6 +105,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .. }) => { if is_single_distinct_agg(plan)? && !contains_grouping_set(group_expr) { + let fields = schema.fields(); // alias all original group_by exprs let (mut inner_group_exprs, out_group_expr_with_alias): ( Vec, @@ -110,37 +114,34 @@ impl OptimizerRule for SingleDistinctToGroupBy { .iter() .enumerate() .map(|(i, group_expr)| { - let (alias_expr, out_group_expr, original_name) = - if let Expr::Column(_) = group_expr { - // For Column expressions we can use existing expression as is. - (group_expr.clone(), group_expr.clone(), None) - } else { - // For complex expression write is as alias, to be able to refer - // if from parent operators successfully. - // Consider plan below. - // - // Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ - // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ - // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] - // - // First aggregate(from bottom) refers to `test.a` column. - // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. - // If we were to write plan above as below without alias - // - // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ - // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ - // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] - // - // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it. - let alias_str = format!("group_alias_{i}"); - let alias_expr = group_expr.clone().alias(&alias_str); - ( - alias_expr, - col(alias_str), - Some(schema.fields()[i].qualified_name()), - ) - }; - (alias_expr, (out_group_expr, original_name)) + if let Expr::Column(_) = group_expr { + // For Column expressions we can use existing expression as is. + (group_expr.clone(), (group_expr.clone(), None)) + } else { + // For complex expression write is as alias, to be able to refer + // if from parent operators successfully. + // Consider plan below. + // + // Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // First aggregate(from bottom) refers to `test.a` column. + // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. + // If we were to write plan above as below without alias + // + // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it. + let alias_str = format!("group_alias_{i}"); + let alias_expr = group_expr.clone().alias(&alias_str); + ( + alias_expr, + (col(alias_str), Some(fields[i].qualified_name())), + ) + } }) .unzip(); @@ -209,24 +210,22 @@ impl OptimizerRule for SingleDistinctToGroupBy { // this optimizer has two kinds of alias: // - group_by aggr // - aggr expr - let mut alias_expr: Vec = Vec::new(); - for (group_expr, original_field) in out_group_expr_with_alias { - let expr = if let Some(name) = original_field { - group_expr.alias(name) - } else { - group_expr - }; - alias_expr.push(expr); - } - for (i, expr) in new_aggr_exprs.iter().enumerate() { - alias_expr.push(columnize_expr( - expr.clone().alias( - schema.clone().fields()[i + group_expr.len()] - .qualified_name(), - ), - &outer_aggr_schema, - )); - } + let group_size = group_expr.len(); + let alias_expr = out_group_expr_with_alias + .into_iter() + .map(|(group_expr, original_field)| { + if let Some(name) = original_field { + group_expr.alias(name) + } else { + group_expr + } + }) + .chain(new_aggr_exprs.iter().enumerate().map(|(idx, expr)| { + let idx = idx + group_size; + let name = fields[idx].qualified_name(); + columnize_expr(expr.clone().alias(name), &outer_aggr_schema) + })) + .collect(); let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(inner_agg), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 49efde72cbc5..af4612272676 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -17,18 +17,22 @@ //! Array expressions +use std::any::type_name; +use std::sync::Arc; + use arrow::array::*; use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow_buffer::NullBuffer; -use core::any::type_name; use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_array}; -use datafusion_common::{exec_err, internal_err, not_impl_err, plan_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::utils::wrap_into_list_array; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, +}; use datafusion_expr::ColumnarValue; + use itertools::Itertools; -use std::sync::Arc; macro_rules! downcast_arg { ($ARG:expr, $ARRAY_TYPE:ident) => {{ @@ -400,34 +404,26 @@ fn array(values: &[ColumnarValue]) -> Result { .iter() .map(|x| match x { ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array(), }) .collect(); - let mut data_type = None; + let mut data_type = DataType::Null; for arg in &arrays { let arg_data_type = arg.data_type(); if !arg_data_type.equals_datatype(&DataType::Null) { - data_type = Some(arg_data_type.clone()); + data_type = arg_data_type.clone(); break; - } else { - data_type = Some(DataType::Null); } } match data_type { - // empty array - None => { - let list_arr = ScalarValue::new_list(&[], &DataType::Null); - Ok(Arc::new(list_arr)) - } - // all nulls, return a NullArray. - Some(DataType::Null) => { - let null_arr = vec![ScalarValue::Null; arrays.len()]; - let list_arr = ScalarValue::new_list(null_arr.as_slice(), &DataType::Null); - Ok(Arc::new(list_arr)) + // Either an empty array or all nulls: + DataType::Null => { + let array = new_null_array(&DataType::Null, arrays.len()); + Ok(Arc::new(wrap_into_list_array(array))) } - Some(data_type) => Ok(array_array(arrays.as_slice(), data_type)?), + data_type => array_array(arrays.as_slice(), data_type), } } From a8cd920a2fa76ed67426f8bb37c3f89c9694e8d6 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Tue, 24 Oct 2023 16:13:26 +0300 Subject: [PATCH 13/15] Improve comments --- datafusion/expr/src/built_in_function.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index c576afb37beb..b2262a107143 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -366,8 +366,9 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; if lhs_field.name() != rhs_field.name() { return Err(exec_datafusion_err!( - "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", - lhs, rhs + "There is no wider type that can represent both lhs: {:?}, rhs:{:?}", + lhs, + rhs )); } assert_eq!(lhs_field.name(), rhs_field.name()); @@ -377,8 +378,9 @@ fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { } (_, _) => { return Err(exec_datafusion_err!( - "There is no wider type (that can represent other) among lhs: {:?}, rhs:{:?}", - lhs, rhs + "There is no wider type that can represent both lhs: {:?}, rhs:{:?}", + lhs, + rhs )); } }) From 0741b391754ef862c7e8415e9c03867a0f725be2 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 24 Oct 2023 17:56:13 +0300 Subject: [PATCH 14/15] Move get_wider_type to type_coercion module --- datafusion/expr/src/built_in_function.rs | 70 +------------------- datafusion/expr/src/type_coercion/binary.rs | 73 ++++++++++++++++++++- 2 files changed, 73 insertions(+), 70 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index b2262a107143..e2420bce9e8a 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -33,10 +33,10 @@ use crate::{ use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; use datafusion_common::{ - exec_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, - Result, + internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, }; +use crate::type_coercion::binary::get_wider_type; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -320,72 +320,6 @@ fn function_to_name() -> &'static HashMap { }) } -/// Returns the wider type among lhs and rhs. -/// Wider type is the type that can safely represent the other type without information loss. -/// Returns Error if types are incompatible. -fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { - Ok(match (lhs, rhs) { - (lhs, rhs) if lhs == rhs => lhs.clone(), - (DataType::Null, _) => rhs.clone(), - (_, DataType::Null) => lhs.clone(), - // Right UInt is larger than left UInt - (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64) => { - rhs.clone() - } - (DataType::UInt16, DataType::UInt32 | DataType::UInt64) => rhs.clone(), - (DataType::UInt32, DataType::UInt64) => rhs.clone(), - // Left UInt is larger than right UInt. - (DataType::UInt16 | DataType::UInt32 | DataType::UInt64, DataType::UInt8) => { - lhs.clone() - } - (DataType::UInt32 | DataType::UInt64, DataType::UInt16) => lhs.clone(), - (DataType::UInt64, DataType::UInt32) => lhs.clone(), - // Right Int is larger than left Int - (DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64) => { - rhs.clone() - } - (DataType::Int16, DataType::Int32 | DataType::Int64) => rhs.clone(), - (DataType::Int32, DataType::Int64) => rhs.clone(), - // Left Int is larger than right Int. - (DataType::Int16 | DataType::Int32 | DataType::Int64, DataType::Int8) => { - lhs.clone() - } - (DataType::Int32 | DataType::Int64, DataType::Int16) => lhs.clone(), - (DataType::Int64, DataType::Int32) => lhs.clone(), - // Right Float is larger than left Float - (DataType::Float16, DataType::Float32 | DataType::Float64) => rhs.clone(), - (DataType::Float32, DataType::Float64) => rhs.clone(), - // Left Float is larger than right Float. - (DataType::Float32 | DataType::Float64, DataType::Float16) => lhs.clone(), - (DataType::Float64, DataType::Float32) => lhs.clone(), - // String - (DataType::Utf8, DataType::LargeUtf8) => rhs.clone(), - (DataType::LargeUtf8, DataType::Utf8) => lhs.clone(), - (DataType::List(lhs_field), DataType::List(rhs_field)) => { - let field_type = - get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; - if lhs_field.name() != rhs_field.name() { - return Err(exec_datafusion_err!( - "There is no wider type that can represent both lhs: {:?}, rhs:{:?}", - lhs, - rhs - )); - } - assert_eq!(lhs_field.name(), rhs_field.name()); - let field_name = lhs_field.name(); - let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); - DataType::List(Arc::new(Field::new(field_name, field_type, nullable))) - } - (_, _) => { - return Err(exec_datafusion_err!( - "There is no wider type that can represent both lhs: {:?}, rhs:{:?}", - lhs, - rhs - )); - } - }) -} - impl BuiltinScalarFunction { /// an allowlist of functions to take zero arguments, so that they will get special treatment /// while executing. diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index f028b721ce08..5bf7360899b2 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -18,15 +18,18 @@ //! Coercion rules for matching argument types for binary operators use crate::Operator; +use std::sync::Arc; use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; +use datafusion_common::{ + exec_datafusion_err, plan_datafusion_err, plan_err, DataFusionError, Result, +}; /// The type signature of an instantiation of binary operator expression such as /// `lhs + rhs` @@ -475,6 +478,72 @@ fn get_wider_decimal_type( } } +/// Returns the wider type among lhs and rhs. +/// Wider type is the type that can safely represent the other type without information loss. +/// Returns Error if types are incompatible. +pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { + Ok(match (lhs, rhs) { + (lhs, rhs) if lhs == rhs => lhs.clone(), + (DataType::Null, _) => rhs.clone(), + (_, DataType::Null) => lhs.clone(), + // Right UInt is larger than left UInt + (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64) => { + rhs.clone() + } + (DataType::UInt16, DataType::UInt32 | DataType::UInt64) => rhs.clone(), + (DataType::UInt32, DataType::UInt64) => rhs.clone(), + // Left UInt is larger than right UInt. + (DataType::UInt16 | DataType::UInt32 | DataType::UInt64, DataType::UInt8) => { + lhs.clone() + } + (DataType::UInt32 | DataType::UInt64, DataType::UInt16) => lhs.clone(), + (DataType::UInt64, DataType::UInt32) => lhs.clone(), + // Right Int is larger than left Int + (DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64) => { + rhs.clone() + } + (DataType::Int16, DataType::Int32 | DataType::Int64) => rhs.clone(), + (DataType::Int32, DataType::Int64) => rhs.clone(), + // Left Int is larger than right Int. + (DataType::Int16 | DataType::Int32 | DataType::Int64, DataType::Int8) => { + lhs.clone() + } + (DataType::Int32 | DataType::Int64, DataType::Int16) => lhs.clone(), + (DataType::Int64, DataType::Int32) => lhs.clone(), + // Right Float is larger than left Float + (DataType::Float16, DataType::Float32 | DataType::Float64) => rhs.clone(), + (DataType::Float32, DataType::Float64) => rhs.clone(), + // Left Float is larger than right Float. + (DataType::Float32 | DataType::Float64, DataType::Float16) => lhs.clone(), + (DataType::Float64, DataType::Float32) => lhs.clone(), + // String + (DataType::Utf8, DataType::LargeUtf8) => rhs.clone(), + (DataType::LargeUtf8, DataType::Utf8) => lhs.clone(), + (DataType::List(lhs_field), DataType::List(rhs_field)) => { + let field_type = + get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; + if lhs_field.name() != rhs_field.name() { + return Err(exec_datafusion_err!( + "There is no wider type that can represent both lhs: {:?}, rhs:{:?}", + lhs, + rhs + )); + } + assert_eq!(lhs_field.name(), rhs_field.name()); + let field_name = lhs_field.name(); + let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); + DataType::List(Arc::new(Field::new(field_name, field_type, nullable))) + } + (_, _) => { + return Err(exec_datafusion_err!( + "There is no wider type that can represent both lhs: {:?}, rhs:{:?}", + lhs, + rhs + )); + } + }) +} + /// Convert the numeric data type to the decimal data type. /// Now, we just support the signed integer type and floating-point type. fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { From fdee7c352706f0f7bb04c23cae2a9f3294b9a59a Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Tue, 24 Oct 2023 20:45:40 +0300 Subject: [PATCH 15/15] Clean up type coercion file --- datafusion/expr/src/built_in_function.rs | 2 +- datafusion/expr/src/type_coercion/binary.rs | 152 ++++++++------------ 2 files changed, 60 insertions(+), 94 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index e2420bce9e8a..16554133d828 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -25,6 +25,7 @@ use std::sync::{Arc, OnceLock}; use crate::nullif::SUPPORTED_NULLIF_TYPES; use crate::signature::TIMEZONE_WILDCARD; +use crate::type_coercion::binary::get_wider_type; use crate::type_coercion::functions::data_types; use crate::{ conditional_expressions, struct_expressions, utils, FuncMonotonicity, Signature, @@ -36,7 +37,6 @@ use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, }; -use crate::type_coercion::binary::get_wider_type; use strum::IntoEnumIterator; use strum_macros::EnumIter; diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 5bf7360899b2..cf93d15e23f0 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -17,9 +17,10 @@ //! Coercion rules for matching argument types for binary operators -use crate::Operator; use std::sync::Arc; +use crate::Operator; + use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ @@ -67,61 +68,54 @@ impl Signature { /// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs` fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result { + use arrow::datatypes::DataType::*; + use Operator::*; match op { - Operator::Eq | - Operator::NotEq | - Operator::Lt | - Operator::LtEq | - Operator::Gt | - Operator::GtEq | - Operator::IsDistinctFrom | - Operator::IsNotDistinctFrom => { + Eq | + NotEq | + Lt | + LtEq | + Gt | + GtEq | + IsDistinctFrom | + IsNotDistinctFrom => { comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common argument type for comparison operation {lhs} {op} {rhs}" ) }) } - Operator::And | Operator::Or => match (lhs, rhs) { - // logical binary boolean operators can only be evaluated in bools or nulls - (DataType::Boolean, DataType::Boolean) - | (DataType::Null, DataType::Null) - | (DataType::Boolean, DataType::Null) - | (DataType::Null, DataType::Boolean) => Ok(Signature::uniform(DataType::Boolean)), - _ => plan_err!( + And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { + // Logical binary boolean operators can only be evaluated for + // boolean or null arguments. + Ok(Signature::uniform(DataType::Boolean)) + } else { + plan_err!( "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" - ), - }, - Operator::RegexMatch | - Operator::RegexIMatch | - Operator::RegexNotMatch | - Operator::RegexNotIMatch => { + ) + } + RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => { regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common argument type for regex operation {lhs} {op} {rhs}" ) }) } - Operator::BitwiseAnd - | Operator::BitwiseOr - | Operator::BitwiseXor - | Operator::BitwiseShiftRight - | Operator::BitwiseShiftLeft => { + BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => { bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common type for bitwise operation {lhs} {op} {rhs}" ) }) } - Operator::StringConcat => { + StringConcat => { string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common string type for string concat operation {lhs} {op} {rhs}" ) }) } - Operator::AtArrow - | Operator::ArrowAt => { + AtArrow | ArrowAt => { // ArrowAt and AtArrow check for whether one array ic contained in another. // The result type is boolean. Signature::comparison defines this signature. // Operation has nothing to do with comparison @@ -131,22 +125,18 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result ) }) } - Operator::Plus | - Operator::Minus | - Operator::Multiply | - Operator::Divide| - Operator::Modulo => { + Plus | Minus | Multiply | Divide | Modulo => { let get_result = |lhs, rhs| { use arrow::compute::kernels::numeric::*; let l = new_empty_array(lhs); let r = new_empty_array(rhs); let result = match op { - Operator::Plus => add_wrapping(&l, &r), - Operator::Minus => sub_wrapping(&l, &r), - Operator::Multiply => mul_wrapping(&l, &r), - Operator::Divide => div(&l, &r), - Operator::Modulo => rem(&l, &r), + Plus => add_wrapping(&l, &r), + Minus => sub_wrapping(&l, &r), + Multiply => mul_wrapping(&l, &r), + Divide => div(&l, &r), + Modulo => rem(&l, &r), _ => unreachable!(), }; result.map(|x| x.data_type().clone()) @@ -233,7 +223,7 @@ fn math_decimal_coercion( (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { Some((dec_type.clone(), dec_type.clone())) } - (Decimal128(_, _), Decimal128(_, _)) => { + (Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => { Some((lhs_type.clone(), rhs_type.clone())) } // Unlike with comparison we don't coerce to a decimal in the case of floating point @@ -244,9 +234,6 @@ fn math_decimal_coercion( (Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => { Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone())) } - (Decimal256(_, _), Decimal256(_, _)) => { - Some((lhs_type.clone(), rhs_type.clone())) - } (Decimal256(_, _), Int8 | Int16 | Int32 | Int64) => Some(( lhs_type.clone(), coerce_numeric_type_to_decimal256(rhs_type)?, @@ -478,67 +465,49 @@ fn get_wider_decimal_type( } } -/// Returns the wider type among lhs and rhs. -/// Wider type is the type that can safely represent the other type without information loss. -/// Returns Error if types are incompatible. +/// Returns the wider type among arguments `lhs` and `rhs`. +/// The wider type is the type that can safely represent values from both types +/// without information loss. Returns an Error if types are incompatible. pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { + use arrow::datatypes::DataType::*; Ok(match (lhs, rhs) { (lhs, rhs) if lhs == rhs => lhs.clone(), - (DataType::Null, _) => rhs.clone(), - (_, DataType::Null) => lhs.clone(), - // Right UInt is larger than left UInt - (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64) => { - rhs.clone() - } - (DataType::UInt16, DataType::UInt32 | DataType::UInt64) => rhs.clone(), - (DataType::UInt32, DataType::UInt64) => rhs.clone(), + // Right UInt is larger than left UInt. + (UInt8, UInt16 | UInt32 | UInt64) | (UInt16, UInt32 | UInt64) | (UInt32, UInt64) | + // Right Int is larger than left Int. + (Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) | + // Right Float is larger than left Float. + (Float16, Float32 | Float64) | (Float32, Float64) | + // Right String is larger than left String. + (Utf8, LargeUtf8) | + // Any right type is wider than a left hand side Null. + (Null, _) => rhs.clone(), // Left UInt is larger than right UInt. - (DataType::UInt16 | DataType::UInt32 | DataType::UInt64, DataType::UInt8) => { - lhs.clone() - } - (DataType::UInt32 | DataType::UInt64, DataType::UInt16) => lhs.clone(), - (DataType::UInt64, DataType::UInt32) => lhs.clone(), - // Right Int is larger than left Int - (DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64) => { - rhs.clone() - } - (DataType::Int16, DataType::Int32 | DataType::Int64) => rhs.clone(), - (DataType::Int32, DataType::Int64) => rhs.clone(), + (UInt16 | UInt32 | UInt64, UInt8) | (UInt32 | UInt64, UInt16) | (UInt64, UInt32) | // Left Int is larger than right Int. - (DataType::Int16 | DataType::Int32 | DataType::Int64, DataType::Int8) => { - lhs.clone() - } - (DataType::Int32 | DataType::Int64, DataType::Int16) => lhs.clone(), - (DataType::Int64, DataType::Int32) => lhs.clone(), - // Right Float is larger than left Float - (DataType::Float16, DataType::Float32 | DataType::Float64) => rhs.clone(), - (DataType::Float32, DataType::Float64) => rhs.clone(), + (Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) | // Left Float is larger than right Float. - (DataType::Float32 | DataType::Float64, DataType::Float16) => lhs.clone(), - (DataType::Float64, DataType::Float32) => lhs.clone(), - // String - (DataType::Utf8, DataType::LargeUtf8) => rhs.clone(), - (DataType::LargeUtf8, DataType::Utf8) => lhs.clone(), - (DataType::List(lhs_field), DataType::List(rhs_field)) => { + (Float32 | Float64, Float16) | (Float64, Float32) | + // Left String is larget than right String. + (LargeUtf8, Utf8) | + // Any left type is wider than a right hand side Null. + (_, Null) => lhs.clone(), + (List(lhs_field), List(rhs_field)) => { let field_type = get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; if lhs_field.name() != rhs_field.name() { return Err(exec_datafusion_err!( - "There is no wider type that can represent both lhs: {:?}, rhs:{:?}", - lhs, - rhs + "There is no wider type that can represent both {lhs} and {rhs}." )); } assert_eq!(lhs_field.name(), rhs_field.name()); let field_name = lhs_field.name(); let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); - DataType::List(Arc::new(Field::new(field_name, field_type, nullable))) + List(Arc::new(Field::new(field_name, field_type, nullable))) } (_, _) => { return Err(exec_datafusion_err!( - "There is no wider type that can represent both lhs: {:?}, rhs:{:?}", - lhs, - rhs + "There is no wider type that can represent both {lhs} and {rhs}." )); } }) @@ -879,14 +848,11 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { #[cfg(test)] mod tests { - use arrow::datatypes::DataType; - - use datafusion_common::assert_contains; - use datafusion_common::Result; - + use super::*; use crate::Operator; - use super::*; + use arrow::datatypes::DataType; + use datafusion_common::{assert_contains, Result}; #[test] fn test_coercion_error() -> Result<()> {