diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 3e696d6b93a6..778950cbf926 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -63,6 +63,7 @@ cargo run --example csv_sql - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files - ['parquet_exec_visitor.rs'](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution +- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from Datafusion `Expr` and `LogicalPlan` - [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 342a23b6e73d..cf284472212f 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -31,8 +31,8 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl, - GroupsAccumulator, Signature, + function::{AccumulatorArgs, StateFieldsArgs}, + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, }; /// This example shows how to use the full AggregateUDFImpl API to implement a user @@ -92,21 +92,16 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_fields( - &self, - _name: &str, - value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", value_type, true), + Field::new("prod", args.return_type.clone(), true), Field::new("n", DataType::UInt32, true), ]) } /// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator` /// which is used for cases when there are grouping columns in the query - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index c8063c0eb1e3..d1ef1c6c9dd0 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -15,26 +15,21 @@ // specific language governing permissions and limitations // under the License. 
-use datafusion::{ - arrow::{ - array::{ArrayRef, Float32Array, Float64Array}, - datatypes::DataType, - record_batch::RecordBatch, - }, - logical_expr::Volatility, -}; use std::any::Any; +use std::sync::Arc; -use arrow::array::{new_null_array, Array, AsArray}; +use arrow::array::{ + new_null_array, Array, ArrayRef, AsArray, Float32Array, Float64Array, +}; use arrow::compute; -use arrow::datatypes::Float64Type; +use arrow::datatypes::{DataType, Float64Type}; +use arrow::record_batch::RecordBatch; use datafusion::error::Result; +use datafusion::logical_expr::Volatility; use datafusion::prelude::*; use datafusion_common::{internal_err, ScalarValue}; -use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature, -}; -use std::sync::Arc; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; /// This example shows how to use the full ScalarUDFImpl API to implement a user /// defined function. As in the `simple_udf.rs` example, this struct implements @@ -186,8 +181,9 @@ impl ScalarUDFImpl for PowUdf { &self.aliases } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn monotonicity(&self, input: &[ExprProperties]) -> Result { + // The POW function preserves the order of its argument. + Ok(input[0].sort_properties) } } diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index 3973e50474ba..9e624b66294d 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -15,17 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::result::Result as RResult; +use std::sync::Arc; + use datafusion::error::Result; use datafusion::execution::context::{ FunctionFactory, RegisterFunction, SessionContext, SessionState, }; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{exec_err, internal_err, DataFusionError}; -use datafusion_expr::simplify::ExprSimplifyResult; -use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{CreateFunction, Expr, ScalarUDF, ScalarUDFImpl, Signature}; -use std::result::Result as RResult; -use std::sync::Arc; /// This example shows how to utilize [FunctionFactory] to implement simple /// SQL-macro like functions using a `CREATE FUNCTION` statement. The same @@ -156,8 +157,8 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { &[] } - fn monotonicity(&self) -> Result> { - Ok(None) + fn monotonicity(&self, _input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) } } diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs new file mode 100644 index 000000000000..3915d3991f76 --- /dev/null +++ b/datafusion-examples/examples/plan_to_sql.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::error::Result; + +use datafusion::prelude::*; +use datafusion::sql::unparser::expr_to_sql; +use datafusion_sql::unparser::dialect::CustomDialect; +use datafusion_sql::unparser::Unparser; + +/// This example demonstrates the programmatic construction of +/// SQL using the DataFusion [`Expr`] and [`LogicalPlan`] APIs. +/// +/// +/// The code in this example shows how to: +/// 1. Create SQL from a variety of Expr and LogicalPlan: [`main`] +/// 2. Create a simple expression [`Expr`] with the fluent API +/// and convert it to SQL: [`simple_expr_to_sql_demo`] +/// 3. Create a simple expression [`Expr`] with the fluent API +/// and convert it to SQL without escaping column names: [`simple_expr_to_sql_demo_no_escape`] +/// 4. Create a simple expression [`Expr`] with the fluent API +/// and convert it to SQL, escaping column names MySQL style: [`simple_expr_to_sql_demo_escape_mysql_style`] + +#[tokio::main] +async fn main() -> Result<()> { + // See how to evaluate expressions + simple_expr_to_sql_demo()?; + simple_expr_to_sql_demo_no_escape()?; + simple_expr_to_sql_demo_escape_mysql_style()?; + Ok(()) +} + +/// DataFusion can convert expressions to SQL, using column name escaping +/// PostgreSQL style. +fn simple_expr_to_sql_demo() -> Result<()> { + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + let ast = expr_to_sql(&expr)?; + let sql = format!("{}", ast); + assert_eq!(sql, r#"(("a" < 5) OR ("a" = 8))"#); + Ok(()) +} + +/// DataFusion can convert expressions to SQL without escaping column names, +/// using a custom dialect and an explicit unparser. +fn simple_expr_to_sql_demo_no_escape() -> Result<()> { + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + let dialect = CustomDialect::new(None); + let unparser = Unparser::new(&dialect); + let ast = unparser.expr_to_sql(&expr)?; + let sql = format!("{}", ast); + assert_eq!(sql, r#"((a < 5) OR (a = 8))"#); + Ok(()) +} + +/// DataFusion can convert expressions to SQL, escaping column names MySQL +/// style (with backticks), using a custom dialect and an explicit unparser. +fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> { + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + let dialect = CustomDialect::new(Some('`')); + let unparser = Unparser::new(&dialect); + let ast = unparser.expr_to_sql(&expr)?; + let sql = format!("{}", ast); + assert_eq!(sql, r#"((`a` < 5) OR (`a` = 8))"#); + Ok(()) +} diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs index 92deb20272e4..08b6bcab0190 100644 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -17,7 +17,7 @@ use arrow_schema::{Field, Schema}; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use datafusion_expr::function::AggregateFunctionSimplification; +use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs}; use datafusion_expr::simplify::SimplifyInfo; use std::{any::Any, sync::Arc}; @@ -70,16 +70,11 @@ impl AggregateUDFImpl for BetterAvgUdaf {
unimplemented!("should not be invoked") } - fn state_fields( - &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!("should not be invoked") } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 9d58465191e1..8c6790541597 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -40,7 +40,7 @@ use arrow::ipc::reader::FileReader; use arrow::ipc::writer::IpcWriteOptions; use arrow::ipc::{root_as_message, CompressionType}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use datafusion_common::{not_impl_err, DataFusionError, FileType, Statistics}; +use datafusion_common::{not_impl_err, DataFusionError, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; @@ -136,10 +136,6 @@ impl FileFormat for ArrowFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::ARROW - } } /// Implements [`DataSink`] for writing to arrow_ipc files diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 132dae14c684..7b2c26a2c4f9 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::FileType; use datafusion_physical_expr::PhysicalExpr; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; @@ -89,10 +88,6 @@ impl FileFormat for AvroFormat { let exec = AvroExec::new(conf); Ok(Arc::new(exec)) } - - fn file_type(&self) -> FileType { - FileType::AVRO - } } #[cfg(test)] diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 17bc7aafce85..ae5ac52025cf 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -42,7 +42,7 @@ use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Fields, Schema}; use datafusion_common::config::CsvOptions; use datafusion_common::file_options::csv_writer::CsvWriterOptions; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; @@ -280,10 +280,6 @@ impl FileFormat for CsvFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::CSV - } } impl CsvFormat { @@ -549,8 +545,9 @@ mod tests { use arrow::compute::concat_batches; use datafusion_common::cast::as_string_array; + use datafusion_common::internal_err; use datafusion_common::stats::Precision; - use datafusion_common::{internal_err, GetExt}; + use datafusion_common::{FileType, GetExt}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::{col, lit}; diff --git a/datafusion/core/src/datasource/file_format/json.rs 
b/datafusion/core/src/datasource/file_format/json.rs index 9f526e1c87b4..6e6c79848594 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -43,7 +43,7 @@ use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; use arrow_array::RecordBatch; use datafusion_common::config::JsonOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; -use datafusion_common::{not_impl_err, FileType}; +use datafusion_common::not_impl_err; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; @@ -184,10 +184,6 @@ impl FileFormat for JsonFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::JSON - } } impl Default for JsonSerializer { diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index fdb89a264951..243a91b7437b 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -41,7 +41,7 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use datafusion_common::{not_impl_err, FileType}; +use datafusion_common::not_impl_err; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use async_trait::async_trait; @@ -104,9 +104,6 @@ pub trait FileFormat: Send + Sync + fmt::Debug { ) -> Result> { not_impl_err!("Writer not implemented for this format") } - - /// Returns the FileType corresponding to this FileFormat - fn file_type(&self) -> FileType; } #[cfg(test)] diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index fa379eb5b445..8182ced6f228 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -47,7 +47,7 @@ use datafusion_common::config::TableParquetOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::stats::Precision; use datafusion_common::{ - exec_err, internal_datafusion_err, not_impl_err, DataFusionError, FileType, + exec_err, internal_datafusion_err, not_impl_err, DataFusionError, }; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; @@ -286,10 +286,6 @@ impl FileFormat for ParquetFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::PARQUET - } } fn summarize_min_max( diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 505748860388..1a82dac4658c 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -30,6 +30,7 @@ use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::udaf::AggregateFunctionExpr; /// Optimizer that uses available statistics for aggregate functions #[derive(Default)] @@ -57,13 +58,9 @@ impl PhysicalOptimizerRule for AggregateStatistics { let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { if let Some((non_null_rows, name)) = - 
take_optimizable_column_count(&**expr, &stats) + take_optimizable_column_and_table_count(&**expr, &stats) { projections.push((expressions::lit(non_null_rows), name.to_owned())); - } else if let Some((num_rows, name)) = - take_optimizable_table_count(&**expr, &stats) - { - projections.push((expressions::lit(num_rows), name.to_owned())); } else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) { projections.push((expressions::lit(min), name.to_owned())); } else if let Some((max, name)) = take_optimizable_max(&**expr, &stats) { @@ -137,43 +134,48 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> None } -/// If this agg_expr is a count that is exactly defined in the statistics, return it. -fn take_optimizable_table_count( +/// If this agg_expr is a count that can be exactly derived from the statistics, return it. +fn take_optimizable_column_and_table_count( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( - &stats.num_rows, - agg_expr.as_any().downcast_ref::(), - ) { - // TODO implementing Eq on PhysicalExpr would help a lot here - if casted_expr.expressions().len() == 1 { - if let Some(lit_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - casted_expr.name().to_owned(), - )); + let col_stats = &stats.column_statistics; + if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { + if agg_expr.fun().name() == "COUNT" && !agg_expr.is_distinct() { + if let Precision::Exact(num_rows) = stats.num_rows { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = + exprs[0].as_any().downcast_ref::() + { + let current_val = &col_stats[col_expr.index()].null_count; + if let &Precision::Exact(val) = current_val { + return Some(( + ScalarValue::Int64(Some((num_rows - val) as i64)), + agg_expr.name().to_string(), + )); + } + } else if let Some(lit_expr) = + exprs[0].as_any().downcast_ref::() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(( + ScalarValue::Int64(Some(num_rows as i64)), + agg_expr.name().to_string(), + )); + } + } } } } } - None -} - -/// If this agg_expr is a count that can be exactly derived from the statistics, return it. 
-fn take_optimizable_column_count( - agg_expr: &dyn AggregateExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - let col_stats = &stats.column_statistics; - if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( + // TODO: Remove this after revmoing Builtin Count + else if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( &stats.num_rows, agg_expr.as_any().downcast_ref::(), ) { + // TODO implementing Eq on PhysicalExpr would help a lot here if casted_expr.expressions().len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = casted_expr.expressions()[0] @@ -187,6 +189,16 @@ fn take_optimizable_column_count( casted_expr.name().to_string(), )); } + } else if let Some(lit_expr) = casted_expr.expressions()[0] + .as_any() + .downcast_ref::() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(( + ScalarValue::Int64(Some(num_rows as i64)), + casted_expr.name().to_owned(), + )); + } } } } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index c07f2c5dcf24..cd84e911d381 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -3572,7 +3572,11 @@ pub(crate) mod tests { expr: col("c", &schema).unwrap(), options: SortOptions::default(), }]; - let alias = vec![("a".to_string(), "a".to_string())]; + let alias = vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "b".to_string()), + ("c".to_string(), "c".to_string()), + ]; let plan = sort_preserving_merge_exec( sort_key.clone(), sort_exec( @@ -3585,7 +3589,7 @@ pub(crate) mod tests { let expected = &[ "SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", // Since this projection is trivial, increasing parallelism is not beneficial - "ProjectionExec: expr=[a@0 as a]", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ]; assert_optimized!(expected, plan.clone(), true); diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 042a0198bfb5..135a59aa0353 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -39,8 +39,8 @@ use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use arrow_schema::Schema; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, JoinSide, JoinType}; +use datafusion_expr::sort_properties::SortProperties; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::sort_properties::SortProperties; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; /// The [`JoinSelection`] rule tries to modify a given plan so that it can @@ -561,7 +561,7 @@ fn hash_join_convert_symmetric_subrule( let name = schema.field(*index).name(); let col = Arc::new(Column::new(name, *index)) as _; // Check if the column is ordered. 
- equivalence.get_expr_ordering(col).data + equivalence.get_expr_properties(col).sort_properties != SortProperties::Unordered }, ) diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 08cbf68fa617..416985983dfe 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -112,11 +112,6 @@ impl PhysicalOptimizer { // Remove the ancillary output requirement operator since we are done with the planning // phase. Arc::new(OutputRequirements::new_remove_mode()), - // The PipelineChecker rule will reject non-runnable query plans that use - // pipeline-breaking operators on infinite input(s). The rule generates a - // diagnostic error message when this happens. It makes no changes to the - // given query plan; i.e. it only acts as a final gatekeeping rule. - Arc::new(PipelineChecker::new()), // The aggregation limiter will try to find situations where the accumulator count // is not tied to the cardinality, i.e. when the output of the aggregation is passed // into an `order by max(x) limit y`. In this case it will copy the limit value down @@ -129,6 +124,11 @@ impl PhysicalOptimizer { // are not present, the load of executors such as join or union will be // reduced by narrowing their input tables. Arc::new(ProjectionPushdown::new()), + // The PipelineChecker rule will reject non-runnable query plans that use + // pipeline-breaking operators on infinite input(s). The rule generates a + // diagnostic error message when this happens. It makes no changes to the + // given query plan; i.e. it only acts as a final gatekeeping rule. + Arc::new(PipelineChecker::new()), ]; Self::with_rules(rules) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 0190f35cc97b..fe1290e40774 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1376,7 +1376,6 @@ mod tests { )), ], DataType::Int32, - None, )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -1442,7 +1441,6 @@ mod tests { )), ], DataType::Int32, - None, )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -1511,7 +1509,6 @@ mod tests { )), ], DataType::Int32, - None, )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -1577,7 +1574,6 @@ mod tests { )), ], DataType::Int32, - None, )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d4a9a949fc41..406196a59146 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -252,31 +252,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { func_def, distinct, args, - filter, + filter: _, order_by, null_treatment: _, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(..) 
=> create_function_physical_name( - func_def.name(), - *distinct, - args, - order_by.as_ref(), - ), - AggregateFunctionDefinition::UDF(fun) => { - // TODO: Add support for filter by in AggregateUDF - if filter.is_some() { - return exec_err!( - "aggregate expression with filter is not supported" - ); - } - - let names = args - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>>()?; - Ok(format!("{}({})", fun.name(), names.join(","))) - } - }, + }) => create_function_physical_name( + func_def.name(), + *distinct, + args, + order_by.as_ref(), + ), Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", @@ -1941,6 +1925,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( physical_input_schema, name, ignore_nulls, + *distinct, )?; (agg_expr, filter, physical_sort_exprs) } diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 2ffac6a775d7..7d155bb16c72 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -597,7 +597,7 @@ async fn test_fn_md5() -> Result<()> { #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_regexp_like() -> Result<()> { - let expr = regexp_like(col("a"), lit("[a-z]")); + let expr = regexp_like(col("a"), lit("[a-z]"), None); let expected = [ "+-----------------------------------+", @@ -612,13 +612,28 @@ async fn test_fn_regexp_like() -> Result<()> { assert_fn_batches!(expr, expected); + let expr = regexp_like(col("a"), lit("abc"), Some(lit("i"))); + + let expected = [ + "+-------------------------------------------+", + "| regexp_like(test.a,Utf8(\"abc\"),Utf8(\"i\")) |", + "+-------------------------------------------+", + "| true |", + "| true |", + "| false |", + "| true |", + "+-------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + Ok(()) } #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_regexp_match() -> Result<()> { - let expr = regexp_match(col("a"), lit("[a-z]")); + let expr = regexp_match(col("a"), lit("[a-z]"), None); let expected = [ "+------------------------------------+", @@ -633,13 +648,28 @@ async fn test_fn_regexp_match() -> Result<()> { assert_fn_batches!(expr, expected); + let expr = regexp_match(col("a"), lit("[A-Z]"), Some(lit("i"))); + + let expected = [ + "+----------------------------------------------+", + "| regexp_match(test.a,Utf8(\"[A-Z]\"),Utf8(\"i\")) |", + "+----------------------------------------------+", + "| [a] |", + "| [a] |", + "| [C] |", + "| [A] |", + "+----------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + Ok(()) } #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_regexp_replace() -> Result<()> { - let expr = regexp_replace(col("a"), lit("[a-z]"), lit("x"), lit("g")); + let expr = regexp_replace(col("a"), lit("[a-z]"), lit("x"), Some(lit("g"))); let expected = [ "+----------------------------------------------------------+", @@ -654,6 +684,21 @@ async fn test_fn_regexp_replace() -> Result<()> { assert_fn_batches!(expr, expected); + let expr = regexp_replace(col("a"), lit("[a-z]"), lit("x"), None); + + let expected = [ + "+------------------------------------------------+", + "| regexp_replace(test.a,Utf8(\"[a-z]\"),Utf8(\"x\")) |", + "+------------------------------------------------+", + "| xbcDEF |", + "| xbc123 |", + "| CBAxef |", + "| 123AxcDef |", + 
"+------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + Ok(()) } diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 6c9c3359ebf4..21ef8a7c2110 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -39,12 +39,12 @@ mod sp_repartition_fuzz_tests { config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, }; use datafusion_physical_expr::{ + equivalence::{EquivalenceClass, EquivalenceProperties}, expressions::{col, Column}, - EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + PhysicalExpr, PhysicalSortExpr, }; use test_utils::add_empty_batches; - use datafusion_physical_expr::equivalence::EquivalenceClass; use itertools::izip; use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; @@ -78,7 +78,7 @@ mod sp_repartition_fuzz_tests { let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f); + eq_properties.add_equal_conditions(col_a, col_f)?; // Column e has constant value. eq_properties = eq_properties.add_constants([col_e.clone()]); diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 2514324a9541..fe0c408dc114 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -22,11 +22,10 @@ use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use arrow_schema::{Field, Schema}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, WindowAggExec, + create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec, }; use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; use datafusion::physical_plan::{collect, InputOrderMode}; @@ -40,7 +39,6 @@ use datafusion_expr::{ }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use itertools::Itertools; use test_utils::add_empty_batches; use hashbrown::HashMap; @@ -276,7 +274,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { }; let extended_schema = - schema_add_window_fields(&args, &schema, &window_fn, fn_name)?; + schema_add_window_field(&args, &schema, &window_fn, fn_name)?; let window_expr = create_window_expr( &window_fn, @@ -683,7 +681,7 @@ async fn run_window_test( exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; } - let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?; + let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?; let usual_window_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( @@ -754,32 +752,6 @@ async fn run_window_test( Ok(()) } -// The planner has fully updated schema before calling the `create_window_expr` -// Replicate the same for this test -fn schema_add_window_fields( - args: &[Arc], - schema: &Arc, - window_fn: &WindowFunctionDefinition, - fn_name: &str, -) -> Result> { - let data_types = args - .iter() - .map(|e| e.clone().as_ref().data_type(schema)) - 
.collect::>>()?; - let window_expr_return_type = window_fn.return_type(&data_types)?; - let mut window_fields = schema - .fields() - .iter() - .map(|f| f.as_ref().clone()) - .collect_vec(); - window_fields.extend_from_slice(&[Field::new( - fn_name, - window_expr_return_type, - true, - )]); - Ok(Arc::new(Schema::new(window_fields))) -} - /// Return randomly sized record batches with: /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns /// one random int32 column x diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 8f02fb30b013..d199f04ba781 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -725,7 +725,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { panic!("accumulator shouldn't invoke"); } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index def9fcb4c61b..df41cab7bf02 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,26 +15,27 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::sync::Arc; + use arrow::compute::kernels::numeric::add; use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; +use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array, - cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err, + not_impl_err, plan_err, DataFusionError, ExprSchema, Result, ScalarValue, }; -use datafusion_common::{assert_contains, exec_err, internal_err, DataFusionError}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; -use std::any::Any; -use std::sync::Arc; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and @@ -776,10 +777,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { fn aliases(&self) -> &[String] { &[] } - - fn monotonicity(&self) -> Result> { - Ok(None) - } } impl ScalarFunctionWrapper { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 36953742c1bf..a0bd0086aac7 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1892,16 +1892,8 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { order_by, null_treatment, }) => { - match func_def { - AggregateFunctionDefinition::BuiltIn(..) 
=> { - write_function_name(w, func_def.name(), *distinct, args)?; - } - AggregateFunctionDefinition::UDF(fun) => { - write!(w, "{}(", fun.name())?; - write_names_join(w, args, ",")?; - write!(w, ")")?; - } - }; + write_function_name(w, func_def.name(), *distinct, args)?; + if let Some(fe) = filter { write!(w, " FILTER (WHERE {fe})")?; }; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1d976a12cc4f..64763a973687 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -23,6 +23,7 @@ use crate::expr::{ }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, + StateFieldsArgs, }; use crate::{ aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, @@ -690,12 +691,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_fields( - &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(self.state_fields.clone()) } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 4e4d77924a9d..714cfa1af671 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -19,7 +19,7 @@ use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; use std::sync::Arc; @@ -41,11 +41,14 @@ pub type ReturnTypeFunction = /// [`AccumulatorArgs`] contains information about how an aggregate /// function was called, including the types of its arguments and any optional /// ordering expressions. +#[derive(Debug)] pub struct AccumulatorArgs<'a> { /// The return type of the aggregate function. pub data_type: &'a DataType, + /// The schema of the input arguments pub schema: &'a Schema, + /// Whether to ignore nulls. /// /// SQL allows the user to specify `IGNORE NULLS`, for example: @@ -66,22 +69,40 @@ pub struct AccumulatorArgs<'a> { /// /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. pub sort_exprs: &'a [Expr], + + /// Whether the aggregate function is distinct. + /// + /// ```sql + /// SELECT COUNT(DISTINCT column1) FROM t; + /// ``` + pub is_distinct: bool, + + /// The input type of the aggregate function. + pub input_type: &'a DataType, + + /// The number of arguments the aggregate function takes. + pub args_num: usize, } -impl<'a> AccumulatorArgs<'a> { - pub fn new( - data_type: &'a DataType, - schema: &'a Schema, - ignore_nulls: bool, - sort_exprs: &'a [Expr], - ) -> Self { - Self { - data_type, - schema, - ignore_nulls, - sort_exprs, - } - } +/// [`StateFieldsArgs`] contains information about the fields that an +/// aggregate function's accumulator should have. Used for [`AggregateUDFImpl::state_fields`]. +/// +/// [`AggregateUDFImpl::state_fields`]: crate::udaf::AggregateUDFImpl::state_fields +pub struct StateFieldsArgs<'a> { + /// The name of the aggregate function. + pub name: &'a str, + + /// The input type of the aggregate function. + pub input_type: &'a DataType, + + /// The return type of the aggregate function. + pub return_type: &'a DataType, + + /// The ordering fields of the aggregate function. + pub ordering_fields: &'a [Field], + + /// Whether the aggregate function is distinct. + pub is_distinct: bool, } /// Factory that returns an accumulator for the given aggregate function. 
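Editor's note, for orientation only (not part of the patch): the `function.rs` hunk above replaces the separate `name` / `value_type` / `ordering_fields` parameters of `state_fields` with a single `StateFieldsArgs` struct. The sketch below builds a `StateFieldsArgs` by hand and derives the same kind of state schema as the default `AggregateUDFImpl::state_fields` shown later in this diff. The `describe_state` helper and the `"geo_mean"` name are hypothetical, and the snippet assumes an `arrow`/`datafusion-expr` dependency matching this branch.

```rust
use arrow::datatypes::{DataType, Field};
use datafusion_expr::function::StateFieldsArgs;

// Mirrors the default `state_fields` body in this PR: one "value" field typed
// like the aggregate's return type, followed by any ordering fields.
fn describe_state(args: StateFieldsArgs) -> Vec<Field> {
    let mut fields = vec![Field::new(
        format!("{}[value]", args.name),
        args.return_type.clone(),
        true,
    )];
    fields.extend(args.ordering_fields.to_vec());
    fields
}

fn main() {
    // All of the former positional parameters now travel together in one struct.
    let args = StateFieldsArgs {
        name: "geo_mean", // hypothetical aggregate name
        input_type: &DataType::Float64,
        return_type: &DataType::Float64,
        ordering_fields: &[],
        is_distinct: false,
    };
    for field in describe_state(args) {
        println!("{}: {:?}", field.name(), field.data_type());
    }
}
```

One practical effect of this refactor is that adding another piece of call-site information (such as `is_distinct`) no longer requires changing every implementor's `state_fields` signature.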
diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr/src/interval_arithmetic.rs index ca91a8c9da00..c4890b97e748 100644 --- a/datafusion/expr/src/interval_arithmetic.rs +++ b/datafusion/expr/src/interval_arithmetic.rs @@ -273,19 +273,34 @@ impl Interval { unreachable!(); }; // Standardize boolean interval endpoints: - Self { + return Self { lower: ScalarValue::Boolean(Some(lower_bool.unwrap_or(false))), upper: ScalarValue::Boolean(Some(upper_bool.unwrap_or(true))), - } + }; } - // Standardize floating-point endpoints: - else if lower.data_type() == DataType::Float32 { - handle_float_intervals!(Float32, f32, lower, upper) - } else if lower.data_type() == DataType::Float64 { - handle_float_intervals!(Float64, f64, lower, upper) - } else { + match lower.data_type() { + // Standardize floating-point endpoints: + DataType::Float32 => handle_float_intervals!(Float32, f32, lower, upper), + DataType::Float64 => handle_float_intervals!(Float64, f64, lower, upper), + // Unsigned null values for lower bounds are set to zero: + DataType::UInt8 if lower.is_null() => Self { + lower: ScalarValue::UInt8(Some(0)), + upper, + }, + DataType::UInt16 if lower.is_null() => Self { + lower: ScalarValue::UInt16(Some(0)), + upper, + }, + DataType::UInt32 if lower.is_null() => Self { + lower: ScalarValue::UInt32(Some(0)), + upper, + }, + DataType::UInt64 if lower.is_null() => Self { + lower: ScalarValue::UInt64(Some(0)), + upper, + }, // Other data types do not require standardization: - Self { lower, upper } + _ => Self { lower, upper }, } } @@ -299,6 +314,12 @@ impl Interval { Self::try_new(ScalarValue::from(lower), ScalarValue::from(upper)) } + /// Creates a singleton zero interval if the datatype supported. + pub fn make_zero(data_type: &DataType) -> Result { + let zero_endpoint = ScalarValue::new_zero(data_type)?; + Ok(Self::new(zero_endpoint.clone(), zero_endpoint)) + } + /// Creates an unbounded interval from both sides if the datatype supported. pub fn make_unbounded(data_type: &DataType) -> Result { let unbounded_endpoint = ScalarValue::try_from(data_type)?; @@ -369,7 +390,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. - pub(crate) fn gt>(&self, other: T) -> Result { + pub fn gt>(&self, other: T) -> Result { let rhs = other.borrow(); if self.data_type().ne(&rhs.data_type()) { internal_err!( @@ -402,7 +423,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. - pub(crate) fn gt_eq>(&self, other: T) -> Result { + pub fn gt_eq>(&self, other: T) -> Result { let rhs = other.borrow(); if self.data_type().ne(&rhs.data_type()) { internal_err!( @@ -435,7 +456,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. - pub(crate) fn lt>(&self, other: T) -> Result { + pub fn lt>(&self, other: T) -> Result { other.borrow().gt(self) } @@ -446,7 +467,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. 
- pub(crate) fn lt_eq>(&self, other: T) -> Result { + pub fn lt_eq>(&self, other: T) -> Result { other.borrow().gt_eq(self) } @@ -457,7 +478,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. - pub(crate) fn equal>(&self, other: T) -> Result { + pub fn equal>(&self, other: T) -> Result { let rhs = other.borrow(); if get_result_type(&self.data_type(), &Operator::Eq, &rhs.data_type()).is_err() { internal_err!( @@ -480,7 +501,7 @@ impl Interval { /// Compute the logical conjunction of this (boolean) interval with the /// given boolean interval. - pub(crate) fn and>(&self, other: T) -> Result { + pub fn and>(&self, other: T) -> Result { let rhs = other.borrow(); match (&self.lower, &self.upper, &rhs.lower, &rhs.upper) { ( @@ -501,8 +522,31 @@ impl Interval { } } + /// Compute the logical disjunction of this boolean interval with the + /// given boolean interval. + pub fn or>(&self, other: T) -> Result { + let rhs = other.borrow(); + match (&self.lower, &self.upper, &rhs.lower, &rhs.upper) { + ( + &ScalarValue::Boolean(Some(self_lower)), + &ScalarValue::Boolean(Some(self_upper)), + &ScalarValue::Boolean(Some(other_lower)), + &ScalarValue::Boolean(Some(other_upper)), + ) => { + let lower = self_lower || other_lower; + let upper = self_upper || other_upper; + + Ok(Self { + lower: ScalarValue::Boolean(Some(lower)), + upper: ScalarValue::Boolean(Some(upper)), + }) + } + _ => internal_err!("Incompatible data types for logical conjunction"), + } + } + /// Compute the logical negation of this (boolean) interval. - pub(crate) fn not(&self) -> Result { + pub fn not(&self) -> Result { if self.data_type().ne(&DataType::Boolean) { internal_err!("Cannot apply logical negation to a non-boolean interval") } else if self == &Self::CERTAINLY_TRUE { @@ -761,6 +805,18 @@ impl Interval { } .map(|result| result + 1) } + + /// Reflects an [`Interval`] around the point zero. + /// + /// This method computes the arithmetic negation of the interval, reflecting + /// it about the origin of the number line. This operation swaps and negates + /// the lower and upper bounds of the interval. 
+ pub fn arithmetic_negate(self) -> Result { + Ok(Self { + lower: self.upper().clone().arithmetic_negate()?, + upper: self.lower().clone().arithmetic_negate()?, + }) + } } impl Display for Interval { @@ -1885,10 +1941,10 @@ mod tests { let unbounded_cases = vec![ (DataType::Boolean, Boolean(Some(false)), Boolean(Some(true))), - (DataType::UInt8, UInt8(None), UInt8(None)), - (DataType::UInt16, UInt16(None), UInt16(None)), - (DataType::UInt32, UInt32(None), UInt32(None)), - (DataType::UInt64, UInt64(None), UInt64(None)), + (DataType::UInt8, UInt8(Some(0)), UInt8(None)), + (DataType::UInt16, UInt16(Some(0)), UInt16(None)), + (DataType::UInt32, UInt32(Some(0)), UInt32(None)), + (DataType::UInt64, UInt64(Some(0)), UInt64(None)), (DataType::Int8, Int8(None), Int8(None)), (DataType::Int16, Int16(None), Int16(None)), (DataType::Int32, Int32(None), Int32(None)), @@ -1994,6 +2050,10 @@ mod tests { Interval::make(None, Some(1000_i64))?, Interval::make(Some(1000_i64), Some(1500_i64))?, ), + ( + Interval::make(Some(0_u8), Some(0_u8))?, + Interval::make::(None, None)?, + ), ( Interval::try_new( prev_value(ScalarValue::Float32(Some(0.0_f32))), @@ -2036,6 +2096,10 @@ mod tests { Interval::make(Some(-1000_i64), Some(1000_i64))?, Interval::make(None, Some(-1500_i64))?, ), + ( + Interval::make::(None, None)?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), ( Interval::make(Some(0.0_f32), Some(0.0_f32))?, Interval::make(Some(0.0_f32), Some(0.0_f32))?, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e2b68388abb9..bac2f9c14541 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -50,6 +50,7 @@ pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; pub mod simplify; +pub mod sort_properties; pub mod tree_node; pub mod type_coercion; pub mod utils; @@ -77,8 +78,7 @@ pub use logical_plan::*; pub use operator::Operator; pub use partition_evaluator::PartitionEvaluator; pub use signature::{ - ArrayFunctionSignature, FuncMonotonicity, Signature, TypeSignature, Volatility, - TIMEZONE_WILDCARD, + ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{AggregateUDF, AggregateUDFImpl}; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ddf075c2c27b..4872e5acda5e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2407,6 +2407,16 @@ pub enum Distinct { On(DistinctOn), } +impl Distinct { + /// return a reference to the nodes input + pub fn input(&self) -> &Arc { + match self { + Distinct::All(input) => input, + Distinct::On(DistinctOn { input, .. }) => input, + } + } +} + /// Removes duplicate rows from the input #[derive(Clone, PartialEq, Eq, Hash)] pub struct DistinctOn { diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 5d925c8605ee..63b030f0b748 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -343,14 +343,6 @@ impl Signature { } } -/// Monotonicity of the `ScalarFunctionExpr` with respect to its arguments. -/// Each element of this vector corresponds to an argument and indicates whether -/// the function's behavior is monotonic, or non-monotonic/unknown for that argument, namely: -/// - `None` signifies unknown monotonicity or non-monotonicity. -/// - `Some(true)` indicates that the function is monotonically increasing w.r.t. 
the argument in question. -/// - Some(false) indicates that the function is monotonically decreasing w.r.t. the argument in question. -pub type FuncMonotonicity = Vec>; - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr-common/src/sort_properties.rs b/datafusion/expr/src/sort_properties.rs similarity index 77% rename from datafusion/physical-expr-common/src/sort_properties.rs rename to datafusion/expr/src/sort_properties.rs index 47a5d5ba5e3b..7778be2ecf0d 100644 --- a/datafusion/physical-expr-common/src/sort_properties.rs +++ b/datafusion/expr/src/sort_properties.rs @@ -17,9 +17,10 @@ use std::ops::Neg; -use arrow::compute::SortOptions; +use crate::interval_arithmetic::Interval; -use crate::tree_node::ExprContext; +use arrow::compute::SortOptions; +use arrow::datatypes::DataType; /// To propagate [`SortOptions`] across the `PhysicalExpr`, it is insufficient /// to simply use `Option`: There must be a differentiation between @@ -120,29 +121,39 @@ impl SortProperties { impl Neg for SortProperties { type Output = Self; - fn neg(self) -> Self::Output { - match self { - SortProperties::Ordered(SortOptions { - descending, - nulls_first, - }) => SortProperties::Ordered(SortOptions { - descending: !descending, - nulls_first, - }), - SortProperties::Singleton => SortProperties::Singleton, - SortProperties::Unordered => SortProperties::Unordered, + fn neg(mut self) -> Self::Output { + if let SortProperties::Ordered(SortOptions { descending, .. }) = &mut self { + *descending = !*descending; } + self } } -/// The `ExprOrdering` struct is designed to aid in the determination of ordering (represented -/// by [`SortProperties`]) for a given `PhysicalExpr`. When analyzing the orderings -/// of a `PhysicalExpr`, the process begins by assigning the ordering of its leaf nodes. -/// By propagating these leaf node orderings upwards in the expression tree, the overall -/// ordering of the entire `PhysicalExpr` can be derived. -/// -/// This struct holds the necessary state information for each expression in the `PhysicalExpr`. -/// It encapsulates the orderings (`data`) associated with the expression (`expr`), and -/// orderings of the children expressions (`children`). The [`ExprOrdering`] of a parent -/// expression is determined based on the [`ExprOrdering`] states of its children expressions. -pub type ExprOrdering = ExprContext; +/// Represents the properties of a `PhysicalExpr`, including its sorting and range attributes. +#[derive(Debug, Clone)] +pub struct ExprProperties { + pub sort_properties: SortProperties, + pub range: Interval, +} + +impl ExprProperties { + /// Creates a new `ExprProperties` instance with unknown sort properties and unknown range. + pub fn new_unknown() -> Self { + Self { + sort_properties: SortProperties::default(), + range: Interval::make_unbounded(&DataType::Null).unwrap(), + } + } + + /// Sets the sorting properties of the expression and returns the modified instance. + pub fn with_order(mut self, order: SortProperties) -> Self { + self.sort_properties = order; + self + } + + /// Sets the range of the expression and returns the modified instance. + pub fn with_range(mut self, range: Interval) -> Self { + self.range = range; + self + } +} diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 95121d78e7aa..4fd8d51679f0 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,7 +17,9 @@ //! 
[`AggregateUDF`]: User Defined Aggregate Functions -use crate::function::{AccumulatorArgs, AggregateFunctionSimplification}; +use crate::function::{ + AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, +}; use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::{Accumulator, Expr}; @@ -177,18 +179,13 @@ impl AggregateUDF { /// for more details. /// /// This is used to support multi-phase aggregations - pub fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { - self.inner.state_fields(name, value_type, ordering_fields) + pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.inner.state_fields(args) } /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. - pub fn groups_accumulator_supported(&self) -> bool { - self.inner.groups_accumulator_supported() + pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + self.inner.groups_accumulator_supported(args) } /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details. @@ -232,7 +229,7 @@ where /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; -/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs}; +/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; /// # use arrow::datatypes::Schema; /// # use arrow::datatypes::Field; /// #[derive(Debug, Clone)] @@ -261,9 +258,9 @@ where /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. /// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } -/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec) -> Result> { +/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// Ok(vec![ -/// Field::new("value", value_type, true), +/// Field::new("value", args.return_type.clone(), true), /// Field::new("ordering", DataType::UInt32, true) /// ]) /// } @@ -319,19 +316,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// The name of the fields must be unique within the query and thus should /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { - let value_fields = vec![Field::new( - format_state_name(name, "value"), - value_type, + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let fields = vec![Field::new( + format_state_name(args.name, "value"), + args.return_type.clone(), true, )]; - Ok(value_fields.into_iter().chain(ordering_fields).collect()) + Ok(fields + .into_iter() + .chain(args.ordering_fields.to_vec()) + .collect()) } /// If the aggregate expression has a specialized @@ -344,7 +339,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// `Self::accumulator` for certain queries, such as when this aggregate is /// used as a window function or when there no GROUP BY columns in the /// query. - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { false } @@ -389,6 +384,20 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn simplify(&self) -> Option { None } + + /// Returns the reverse expression of the aggregate function. 
+ fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::NotSupported + } +} + +pub enum ReversedUDAF { + /// The expression is the same as the original expression, like SUM, COUNT + Identical, + /// The expression does not support reverse calculation, like ArrayAgg + NotSupported, + /// The expression is different from the original expression + Reversed(Arc), } /// AggregateUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index fadea26e7f4e..921d13ab3583 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,19 +17,20 @@ //! [`ScalarUDF`]: Scalar User Defined Functions +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; + use crate::expr::create_name; +use crate::interval_arithmetic::Interval; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; +use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ - ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, - ScalarFunctionImplementation, Signature, + ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, }; + use arrow::datatypes::DataType; use datafusion_common::{not_impl_err, ExprSchema, Result}; -use std::any::Any; -use std::fmt; -use std::fmt::Debug; -use std::fmt::Formatter; -use std::sync::Arc; /// Logical representation of a Scalar User Defined Function. /// @@ -202,18 +203,63 @@ impl ScalarUDF { Arc::new(move |args| captured.invoke(args)) } - /// This function specifies monotonicity behaviors for User defined scalar functions. - /// - /// See [`ScalarUDFImpl::monotonicity`] for more details. - pub fn monotonicity(&self) -> Result> { - self.inner.monotonicity() - } - /// Get the circuits of inner implementation pub fn short_circuits(&self) -> bool { self.inner.short_circuits() } + /// Computes the output interval for a [`ScalarUDF`], given the input + /// intervals. + /// + /// # Parameters + /// + /// * `inputs` are the intervals for the inputs (children) of this function. + /// + /// # Example + /// + /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`, + /// then the output interval would be `[0, 3]`. + pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + self.inner.evaluate_bounds(inputs) + } + + /// Updates bounds for child expressions, given a known interval for this + /// function. This is used to propagate constraints down through an expression + /// tree. + /// + /// # Parameters + /// + /// * `interval` is the currently known interval for this function. + /// * `inputs` are the current intervals for the inputs (children) of this function. + /// + /// # Returns + /// + /// A `Vec` of new intervals for the children, in order. + /// + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of + /// propagation, may return an empty vector instead of cloning `children`. + /// This is the default (and conservative) return value. + /// + /// # Example + /// + /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the + /// input `a` is given as `[-7, -6]`, then propagation would would return + /// `[-5, 5]`. + pub fn propagate_constraints( + &self, + interval: &Interval, + inputs: &[&Interval], + ) -> Result>> { + self.inner.propagate_constraints(interval, inputs) + } + + /// Calculates the [`SortProperties`] of this function based on its + /// children's properties. 
+ pub fn monotonicity(&self, inputs: &[ExprProperties]) -> Result { + self.inner.monotonicity(inputs) + } + /// See [`ScalarUDFImpl::coerce_types`] for more details. pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { self.inner.coerce_types(arg_types) @@ -387,11 +433,6 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { &[] } - /// This function specifies monotonicity behaviors for User defined scalar functions. - fn monotonicity(&self) -> Result> { - Ok(None) - } - /// Optionally apply per-UDF simplification / rewrite rules. /// /// This can be used to apply function specific simplification rules during @@ -426,6 +467,59 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { false } + /// Computes the output interval for a [`ScalarUDFImpl`], given the input + /// intervals. + /// + /// # Parameters + /// + /// * `children` are the intervals for the children (inputs) of this function. + /// + /// # Example + /// + /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`, + /// then the output interval would be `[0, 3]`. + fn evaluate_bounds(&self, _input: &[&Interval]) -> Result { + // We cannot assume the input datatype is the same of output type. + Interval::make_unbounded(&DataType::Null) + } + + /// Updates bounds for child expressions, given a known interval for this + /// function. This is used to propagate constraints down through an expression + /// tree. + /// + /// # Parameters + /// + /// * `interval` is the currently known interval for this function. + /// * `inputs` are the current intervals for the inputs (children) of this function. + /// + /// # Returns + /// + /// A `Vec` of new intervals for the children, in order. + /// + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of + /// propagation, may return an empty vector instead of cloning `children`. + /// This is the default (and conservative) return value. + /// + /// # Example + /// + /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the + /// input `a` is given as `[-7, -6]`, then propagation would would return + /// `[-5, 5]`. + fn propagate_constraints( + &self, + _interval: &Interval, + _inputs: &[&Interval], + ) -> Result>> { + Ok(Some(vec![])) + } + + /// Calculates the [`SortProperties`] of this function based on its + /// children's properties. + fn monotonicity(&self, _inputs: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) + } + /// Coerce arguments of a function call to types that the function can evaluate. /// /// This function is only called if [`ScalarUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. 
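
Taken together, the hooks above replace the old boolean-vector `FuncMonotonicity`: `monotonicity` now derives a `SortProperties` value from the children's `ExprProperties`, while `evaluate_bounds` and `propagate_constraints` default to conservative interval answers. A minimal sketch of the ordering hook for a hypothetical strictly increasing unary UDF (the helper name is illustrative, not part of this patch):

```rust
use datafusion_common::Result;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};

// Sketch of `ScalarUDFImpl::monotonicity` for a strictly increasing unary
// function: the output simply follows the ordering of the only argument,
// and anything else falls back to the trait's `Unordered` default.
fn my_increasing_udf_monotonicity(
    input: &[ExprProperties],
) -> Result<SortProperties> {
    Ok(input
        .first()
        .map(|arg| arg.sort_properties)
        .unwrap_or(SortProperties::Unordered))
}
```

Order-reversing cases can return `-input[0].sort_properties` instead, as the ABS and COSH changes later in this patch do.
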
Most diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index 1210e1529dbb..6f03b256fd9f 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -30,8 +30,10 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::{ - function::AccumulatorArgs, type_coercion::aggregates::NUMERICS, - utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, + function::{AccumulatorArgs, StateFieldsArgs}, + type_coercion::aggregates::NUMERICS, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Signature, Volatility, }; use datafusion_physical_expr_common::aggregate::stats::StatsType; @@ -101,12 +103,8 @@ impl AggregateUDFImpl for CovarianceSample { Ok(DataType::Float64) } - fn state_fields( - &self, - name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean1"), DataType::Float64, true), @@ -176,12 +174,8 @@ impl AggregateUDFImpl for CovariancePopulation { Ok(DataType::Float64) } - fn state_fields( - &self, - name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean1"), DataType::Float64, true), diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index e3b685e90376..5d3d48344014 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -147,18 +147,13 @@ impl AggregateUDFImpl for FirstValue { .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( - format_state_name(name, "first_value"), - value_type, + format_state_name(args.name, "first_value"), + args.return_type.clone(), true, )]; - fields.extend(ordering_fields); + fields.extend(args.ordering_fields.to_vec()); fields.push(Field::new("is_set", DataType::Boolean, true)); Ok(fields) } diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 27fc623a182b..6c3348d6c1d6 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -15,24 +15,55 @@ // specific language governing permissions and limitations // under the License. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + macro_rules! make_udaf_expr_and_func { ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN( + $($arg: datafusion_expr::Expr,)* + ) -> datafusion_expr::Expr { + datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + $AGGREGATE_UDF_FN(), + vec![$($arg),*], + false, + None, + None, + None, + )) + } + create_func!($UDAF, $AGGREGATE_UDF_FN); + }; + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $distinct:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN( $($arg: datafusion_expr::Expr,)* distinct: bool, - filter: Option>, - order_by: Option>, - null_treatment: Option ) -> datafusion_expr::Expr { datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), vec![$($arg),*], distinct, - filter, - order_by, - null_treatment, + None, + None, + None )) } create_func!($UDAF, $AGGREGATE_UDF_FN); diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index 43d6046f4f82..466a913e35e5 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -22,8 +22,6 @@ use arrow::datatypes::DataType; use arrow::row::{RowConverter, SortField}; use datafusion_common::cast::as_generic_list_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use itertools::Itertools; diff --git a/datafusion/functions-array/src/cardinality.rs b/datafusion/functions-array/src/cardinality.rs index d6f2456313bc..d17965b795ad 100644 --- a/datafusion/functions-array/src/cardinality.rs +++ b/datafusion/functions-array/src/cardinality.rs @@ -24,8 +24,7 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/concat.rs b/datafusion/functions-array/src/concat.rs index a6fed84fa765..d49cef66742f 100644 --- a/datafusion/functions-array/src/concat.rs +++ b/datafusion/functions-array/src/concat.rs @@ -27,8 +27,6 @@ use datafusion_common::Result; use datafusion_common::{ cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims, }; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; use datafusion_expr::{ type_coercion::binary::get_wider_type, ColumnarValue, ScalarUDFImpl, Signature, Volatility, diff 
--git a/datafusion/functions-array/src/dimension.rs b/datafusion/functions-array/src/dimension.rs index 1dc6520f1bc7..0c65da283bbb 100644 --- a/datafusion/functions-array/src/dimension.rs +++ b/datafusion/functions-array/src/dimension.rs @@ -29,8 +29,7 @@ use datafusion_common::{exec_err, plan_err, Result}; use crate::utils::{compute_array_dims, make_scalar_function}; use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; use arrow_schema::Field; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::sync::Arc; make_udf_expr_and_func!( diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-array/src/empty.rs index 9fe2c870496b..c5fe74480fb5 100644 --- a/datafusion/functions-array/src/empty.rs +++ b/datafusion/functions-array/src/empty.rs @@ -23,8 +23,7 @@ use arrow_schema::DataType; use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List}; use datafusion_common::cast::{as_generic_list_array, as_null_array}; use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/except.rs b/datafusion/functions-array/src/except.rs index a56bab1e0611..453b4f77119d 100644 --- a/datafusion/functions-array/src/except.rs +++ b/datafusion/functions-array/src/except.rs @@ -24,8 +24,6 @@ use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, FieldRef}; use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::collections::HashSet; diff --git a/datafusion/functions-array/src/extract.rs b/datafusion/functions-array/src/extract.rs index 842f4ec1b839..152e5f3c4b13 100644 --- a/datafusion/functions-array/src/extract.rs +++ b/datafusion/functions-array/src/extract.rs @@ -35,7 +35,6 @@ use datafusion_common::cast::as_list_array; use datafusion_common::{ exec_err, internal_datafusion_err, plan_err, DataFusionError, Result, }; -use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -419,19 +418,16 @@ where if let (Some(from), Some(to)) = (from_index, to_index) { let stride = stride.map(|s| s.value(row_index)); - // array_slice with stride in duckdb, return empty array if stride is not supported and from > to. 
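
The replacement logic in the following hunk lines defaults a missing stride to 1 and returns an empty slice whenever the stride points away from the target index. A standalone sketch of that predicate, using a hypothetical helper name for illustration only:

```rust
// Hypothetical helper: a slice from `from` to `to` can only be empty when
// the stride walks away from `to`: an ascending range with a negative
// stride, or a descending range with a positive stride.
fn slice_must_be_empty(from: i64, to: i64, stride: i64) -> bool {
    (from <= to && stride.is_negative()) || (from > to && stride.is_positive())
}

fn main() {
    assert!(slice_must_be_empty(1, 5, -1)); // ascending range, stride walks backwards
    assert!(slice_must_be_empty(5, 1, 2)); // descending range, stride walks forwards
    assert!(!slice_must_be_empty(5, 1, -2)); // descending range, negative stride is fine
}
```
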
- if stride.is_none() && from > to { - // return empty array - offsets.push(offsets[row_index]); - continue; - } + // Default stride is 1 if not provided let stride = stride.unwrap_or(1); if stride.is_zero() { return exec_err!( "array_slice got invalid stride: {:?}, it cannot be 0", stride ); - } else if from <= to && stride.is_negative() { + } else if (from <= to && stride.is_negative()) + || (from > to && stride.is_positive()) + { // return empty array offsets.push(offsets[row_index]); continue; diff --git a/datafusion/functions-array/src/flatten.rs b/datafusion/functions-array/src/flatten.rs index 294d41ada7c3..41762157fc6a 100644 --- a/datafusion/functions-array/src/flatten.rs +++ b/datafusion/functions-array/src/flatten.rs @@ -26,8 +26,7 @@ use datafusion_common::cast::{ as_generic_list_array, as_large_list_array, as_list_array, }; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/length.rs b/datafusion/functions-array/src/length.rs index 9cdcaddf8dff..ed04c52584c0 100644 --- a/datafusion/functions-array/src/length.rs +++ b/datafusion/functions-array/src/length.rs @@ -27,8 +27,7 @@ use core::any::type_name; use datafusion_common::cast::{as_generic_list_array, as_int64_array}; use datafusion_common::DataFusionError; use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/macros.rs b/datafusion/functions-array/src/macros.rs index 4e00aa39bd84..a6e0c2ee62be 100644 --- a/datafusion/functions-array/src/macros.rs +++ b/datafusion/functions-array/src/macros.rs @@ -48,8 +48,8 @@ macro_rules! make_udf_expr_and_func { paste::paste! { // "fluent expr_fn" style function #[doc = $DOC] - pub fn $EXPR_FN($($arg: Expr),*) -> Expr { - Expr::ScalarFunction(ScalarFunction::new_udf( + pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( $SCALAR_UDF_FN(), vec![$($arg),*], )) @@ -61,8 +61,8 @@ macro_rules! make_udf_expr_and_func { paste::paste! 
{ // "fluent expr_fn" style function #[doc = $DOC] - pub fn $EXPR_FN(arg: Vec) -> Expr { - Expr::ScalarFunction(ScalarFunction::new_udf( + pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { + datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( $SCALAR_UDF_FN(), arg, )) diff --git a/datafusion/functions-array/src/make_array.rs b/datafusion/functions-array/src/make_array.rs index 4723464dfaf2..a433f2e49326 100644 --- a/datafusion/functions-array/src/make_array.rs +++ b/datafusion/functions-array/src/make_array.rs @@ -28,10 +28,9 @@ use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; use datafusion_common::internal_err; use datafusion_common::{plan_err, utils::array_into_list_array, Result}; -use datafusion_expr::expr::ScalarFunction; use datafusion_expr::type_coercion::binary::comparison_coercion; +use datafusion_expr::TypeSignature; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use datafusion_expr::{Expr, TypeSignature}; use crate::utils::make_scalar_function; diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-array/src/position.rs index efdb7dff0ce6..0002d5c40b3e 100644 --- a/datafusion/functions-array/src/position.rs +++ b/datafusion/functions-array/src/position.rs @@ -19,8 +19,6 @@ use arrow_schema::DataType::{LargeList, List, UInt64}; use arrow_schema::{DataType, Field}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/range.rs b/datafusion/functions-array/src/range.rs index 9a9829f96100..881e86d63ab8 100644 --- a/datafusion/functions-array/src/range.rs +++ b/datafusion/functions-array/src/range.rs @@ -27,8 +27,6 @@ use arrow_schema::DataType::{Date32, Int64, Interval, List}; use arrow_schema::IntervalUnit::MonthDayNano; use datafusion_common::cast::{as_date32_array, as_int64_array, as_interval_mdn_array}; use datafusion_common::{exec_err, not_impl_datafusion_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; use datafusion_expr::{ ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-array/src/remove.rs index 7645c1a57573..8c408f1650e9 100644 --- a/datafusion/functions-array/src/remove.rs +++ b/datafusion/functions-array/src/remove.rs @@ -27,8 +27,7 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, Field}; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/repeat.rs b/datafusion/functions-array/src/repeat.rs index df623c114818..78bcde9eaba7 100644 --- a/datafusion/functions-array/src/repeat.rs +++ b/datafusion/functions-array/src/repeat.rs @@ -29,8 +29,7 @@ use arrow_schema::DataType::{LargeList, List}; use arrow_schema::{DataType, Field}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use 
datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-array/src/replace.rs index 7cea4945836e..8ac32538ad4f 100644 --- a/datafusion/functions-array/src/replace.rs +++ b/datafusion/functions-array/src/replace.rs @@ -27,8 +27,6 @@ use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; use arrow_schema::Field; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use crate::utils::compare_element_to_list; diff --git a/datafusion/functions-array/src/resize.rs b/datafusion/functions-array/src/resize.rs index 63f28c9afa77..7028bd1c33cc 100644 --- a/datafusion/functions-array/src/resize.rs +++ b/datafusion/functions-array/src/resize.rs @@ -25,8 +25,7 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::{DataType, FieldRef}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-array/src/reverse.rs index 3076013899ef..c9988524cabd 100644 --- a/datafusion/functions-array/src/reverse.rs +++ b/datafusion/functions-array/src/reverse.rs @@ -25,8 +25,7 @@ use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, FieldRef}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/set_ops.rs b/datafusion/functions-array/src/set_ops.rs index 40676b7cdcb8..9032a745ef7a 100644 --- a/datafusion/functions-array/src/set_ops.rs +++ b/datafusion/functions-array/src/set_ops.rs @@ -27,8 +27,6 @@ use arrow::row::{RowConverter, SortField}; use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use itertools::Itertools; use std::any::Any; diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-array/src/sort.rs index 16f271ef10ff..2a554bf3d9da 100644 --- a/datafusion/functions-array/src/sort.rs +++ b/datafusion/functions-array/src/sort.rs @@ -25,8 +25,7 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::cast::{as_list_array, as_string_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use 
std::any::Any; use std::sync::Arc; diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs index 4122ddbd45eb..e14e315752b4 100644 --- a/datafusion/functions-array/src/string.rs +++ b/datafusion/functions-array/src/string.rs @@ -24,8 +24,7 @@ use arrow::array::{ UInt8Array, }; use arrow::datatypes::{DataType, Field}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{Expr, TypeSignature}; +use datafusion_expr::TypeSignature; use datafusion_common::{plan_err, DataFusionError, Result}; diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs deleted file mode 100644 index c723fbb42cfc..000000000000 --- a/datafusion/functions-array/src/udf.rs +++ /dev/null @@ -1,688 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`ScalarUDFImpl`] definitions for array functions. - -use arrow::array::{NullArray, StringArray}; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; -use arrow::datatypes::IntervalUnit::MonthDayNano; -use arrow_schema::DataType::{LargeUtf8, List, Utf8}; -use datafusion_common::exec_err; -use datafusion_common::plan_err; -use datafusion_common::Result; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::TypeSignature; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -// Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayToString, - array_to_string, - array delimiter, // arg name - "converts each element to its text representation.", // doc - array_to_string_udf // internal function name -); -#[derive(Debug)] -pub struct ArrayToString { - signature: Signature, - aliases: Vec, -} - -impl ArrayToString { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("array_to_string"), - String::from("list_to_string"), - String::from("array_join"), - String::from("list_join"), - ], - } - } -} - -impl ScalarUDFImpl for ArrayToString { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_to_string" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, - _ => { - return plan_err!("The array_to_string function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_to_string(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - 
-make_udf_function!(StringToArray, - string_to_array, - string delimiter null_string, // arg name - "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`", // doc - string_to_array_udf // internal function name -); -#[derive(Debug)] -pub struct StringToArray { - signature: Signature, - aliases: Vec, -} - -impl StringToArray { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("string_to_array"), - String::from("string_to_list"), - ], - } - } -} - -impl ScalarUDFImpl for StringToArray { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "string_to_array" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - Utf8 | LargeUtf8 => { - List(Arc::new(Field::new("item", arg_types[0].clone(), true))) - } - _ => { - return plan_err!( - "The string_to_array function can only accept Utf8 or LargeUtf8." - ); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let mut args = ColumnarValue::values_to_arrays(args)?; - // Case: delimiter is NULL, needs to be handled as well. - if args[1].as_any().is::() { - args[1] = Arc::new(StringArray::new_null(args[1].len())); - }; - - match args[0].data_type() { - Utf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - LargeUtf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - other => { - exec_err!("unsupported type for string_to_array function as {other}") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayDims, - array_dims, - array, - "returns an array of the array's dimensions.", - array_dims_udf -); - -#[derive(Debug)] -pub struct ArrayDims { - signature: Signature, - aliases: Vec, -} - -impl ArrayDims { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec!["array_dims".to_string(), "list_dims".to_string()], - } - } -} - -impl ScalarUDFImpl for ArrayDims { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_dims" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => { - List(Arc::new(Field::new("item", UInt64, true))) - } - _ => { - return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_dims(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArraySort, - array_sort, - array desc null_first, - "returns sorted array.", - array_sort_udf -); - -#[derive(Debug)] -pub struct ArraySort { - signature: Signature, - aliases: Vec, -} - -impl ArraySort { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec!["array_sort".to_string(), "list_sort".to_string()], - } - } -} - -impl ScalarUDFImpl for ArraySort { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_sort" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: 
&[DataType]) -> Result { - use DataType::*; - match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_sort(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - Cardinality, - cardinality, - array, - "returns the total number of elements in the array.", - cardinality_udf -); - -impl Cardinality { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("cardinality")], - } - } -} - -#[derive(Debug)] -pub struct Cardinality { - signature: Signature, - aliases: Vec, -} -impl ScalarUDFImpl for Cardinality { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "cardinality" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::cardinality(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayNdims, - array_ndims, - array, - "returns the number of dimensions of the array.", - array_ndims_udf -); - -#[derive(Debug)] -pub struct ArrayNdims { - signature: Signature, - aliases: Vec, -} -impl ArrayNdims { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("array_ndims"), String::from("list_ndims")], - } - } -} - -impl ScalarUDFImpl for ArrayNdims { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_ndims" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_ndims function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_ndims(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayEmpty, - array_empty, - array, - "returns true for an empty array or false for a non-empty array.", - array_empty_udf -); - -#[derive(Debug)] -pub struct ArrayEmpty { - signature: Signature, - aliases: Vec, -} -impl ArrayEmpty { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("empty")], - } - } -} - -impl ScalarUDFImpl for ArrayEmpty { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "empty" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match 
arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, - _ => { - return plan_err!("The array_empty function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_empty(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayRepeat, - array_repeat, - element count, // arg name - "returns an array containing element `count` times.", // doc - array_repeat_udf // internal function name -); -#[derive(Debug)] -pub struct ArrayRepeat { - signature: Signature, - aliases: Vec, -} - -impl ArrayRepeat { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_repeat"), String::from("list_repeat")], - } - } -} - -impl ScalarUDFImpl for ArrayRepeat { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_repeat" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_repeat(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayLength, - array_length, - array, - "returns the length of the array dimension.", - array_length_udf -); - -#[derive(Debug)] -pub struct ArrayLength { - signature: Signature, - aliases: Vec, -} -impl ArrayLength { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_length"), String::from("list_length")], - } - } -} - -impl ScalarUDFImpl for ArrayLength { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_length" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_length function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_length(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - Flatten, - flatten, - array, - "flattens an array of arrays into a single array.", - flatten_udf -); - -#[derive(Debug)] -pub struct Flatten { - signature: Signature, - aliases: Vec, -} -impl Flatten { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("flatten")], - } - } -} - -impl ScalarUDFImpl for Flatten { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "flatten" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - fn get_base_type(data_type: &DataType) -> Result { - match data_type { - List(field) | FixedSizeList(field, _) - if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => - { - get_base_type(field.data_type()) - } - LargeList(field) if matches!(field.data_type(), LargeList(_)) => 
{ - get_base_type(field.data_type()) - } - Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), - FixedSizeList(field, _) => Ok(List(field.clone())), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - let data_type = get_base_type(&arg_types[0])?; - Ok(data_type) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::flatten(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayDistinct, - array_distinct, - array, - "return distinct values from the array after removing duplicates.", - array_distinct_udf -); - -#[derive(Debug)] -pub struct ArrayDistinct { - signature: Signature, - aliases: Vec, -} - -impl crate::udf::ArrayDistinct { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec!["array_distinct".to_string(), "list_distinct".to_string()], - } - } -} - -impl ScalarUDFImpl for crate::udf::ArrayDistinct { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_distinct" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_distinct(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index f0689ffd64e9..4f48ab188403 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -93,7 +93,9 @@ pub(crate) fn string_to_datetime_formatted( if let Err(e) = &dt { // no timezone or other failure, try without a timezone - let ndt = parsed.to_naive_datetime_with_offset(0); + let ndt = parsed + .to_naive_datetime_with_offset(0) + .or_else(|_| parsed.to_naive_date().map(|nd| nd.into())); if let Err(e) = &ndt { return Err(err(&e.to_string())); } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index da1797cdae81..51f5c09a0665 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -29,16 +29,17 @@ use arrow::datatypes::DataType::{Null, Timestamp, Utf8}; use arrow::datatypes::IntervalUnit::{DayTime, MonthDayNano}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; -use chrono::{DateTime, Datelike, Duration, Months, TimeDelta, Utc}; use datafusion_common::cast::as_primitive_array; use datafusion_common::{exec_err, not_impl_err, plan_err, Result, ScalarValue}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - TIMEZONE_WILDCARD, + ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; 
+use chrono::{DateTime, Datelike, Duration, Months, TimeDelta, Utc}; + #[derive(Debug)] pub struct DateBinFunc { signature: Signature, @@ -146,8 +147,21 @@ impl ScalarUDFImpl for DateBinFunc { } } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![None, Some(true)])) + fn monotonicity(&self, input: &[ExprProperties]) -> Result { + // The DATE_BIN function preserves the order of its second argument. + let step = &input[0]; + let date_value = &input[1]; + let reference = input.get(2); + + if step.sort_properties.eq(&SortProperties::Singleton) + && reference + .map(|r| r.sort_properties.eq(&SortProperties::Singleton)) + .unwrap_or(true) + { + Ok(date_value.sort_properties) + } else { + Ok(SortProperties::Unordered) + } } } @@ -425,16 +439,16 @@ fn date_bin_impl( mod tests { use std::sync::Arc; + use crate::datetime::date_bin::{date_bin_nanos_interval, DateBinFunc}; use arrow::array::types::TimestampNanosecondType; use arrow::array::{IntervalDayTimeArray, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::{DataType, TimeUnit}; - use chrono::TimeDelta; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::datetime::date_bin::{date_bin_nanos_interval, DateBinFunc}; + use chrono::TimeDelta; #[test] fn test_date_bin() { diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 0414bf9c2a26..ba5db567a025 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -29,19 +29,18 @@ use arrow::array::types::{ TimestampNanosecondType, TimestampSecondType, }; use arrow::array::{Array, PrimitiveArray}; -use arrow::datatypes::DataType::{Null, Timestamp, Utf8}; -use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, TimeUnit}; -use chrono::{ - DateTime, Datelike, Duration, LocalResult, NaiveDateTime, Offset, TimeDelta, Timelike, -}; - +use arrow::datatypes::DataType::{self, Null, Timestamp, Utf8}; +use arrow::datatypes::TimeUnit::{self, Microsecond, Millisecond, Nanosecond, Second}; use datafusion_common::cast::as_primitive_array; use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - TIMEZONE_WILDCARD, + ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, +}; + +use chrono::{ + DateTime, Datelike, Duration, LocalResult, NaiveDateTime, Offset, TimeDelta, Timelike, }; #[derive(Debug)] @@ -205,8 +204,16 @@ impl ScalarUDFImpl for DateTruncFunc { &self.aliases } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![None, Some(true)])) + fn monotonicity(&self, input: &[ExprProperties]) -> Result { + // The DATE_TRUNC function preserves the order of its second argument. 
+ let precision = &input[0]; + let date_value = &input[1]; + + if precision.sort_properties.eq(&SortProperties::Singleton) { + Ok(date_value.sort_properties) + } else { + Ok(SortProperties::Unordered) + } } } @@ -410,7 +417,10 @@ fn parse_tz(tz: &Option>) -> Result> { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::datetime::date_trunc::{date_trunc_coarse, DateTruncFunc}; + use arrow::array::cast::as_primitive_array; use arrow::array::types::TimestampNanosecondType; use arrow::array::TimestampNanosecondArray; @@ -418,7 +428,6 @@ mod tests { use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use std::sync::Arc; #[test] fn date_trunc_test() { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index a7bcca62944c..af878b4505bc 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -670,6 +670,10 @@ mod tests { parse_timestamp_formatted("09-08-2020 13/42/29", "%m-%d-%Y %H/%M/%S") .unwrap() ); + assert_eq!( + 1642896000000000000, + parse_timestamp_formatted("2022-01-23", "%Y-%m-%d").unwrap() + ); } fn parse_timestamp_formatted(s: &str, format: &str) -> Result { diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 5ee47bd3e8eb..2f14e881d1d8 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -89,7 +89,6 @@ macro_rules! make_udf_function { /// The rationale for providing stub functions is to help users to configure datafusion /// properly (so they get an error telling them why a function is not available) /// instead of getting a cryptic "no function found" message at runtime. - macro_rules! make_stub_package { ($name:ident, $feature:literal) => { #[cfg(not(feature = $feature))] @@ -115,7 +114,6 @@ macro_rules! make_stub_package { /// $ARGS_TYPE: the type of array to cast the argument to /// $RETURN_TYPE: the type of array to return /// $FUNC: the function to apply to each element of $ARG -/// macro_rules! make_function_scalar_inputs_return_type { ($ARG: expr, $NAME:expr, $ARG_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ let arg = downcast_arg!($ARG, $NAME, $ARG_TYPE); @@ -162,14 +160,14 @@ macro_rules! make_math_unary_udf { make_udf_function!($NAME::$UDF, $GNAME, $NAME); mod $NAME { + use std::any::Any; + use std::sync::Arc; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, DataFusionError, Result}; - use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - }; - use std::any::Any; - use std::sync::Arc; + use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] pub struct $UDF { @@ -211,8 +209,11 @@ macro_rules! make_math_unary_udf { } } - fn monotonicity(&self) -> Result> { - Ok($MONOTONICITY) + fn monotonicity( + &self, + input: &[ExprProperties], + ) -> Result { + $MONOTONICITY(input) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -266,15 +267,15 @@ macro_rules! 
make_math_binary_udf { make_udf_function!($NAME::$UDF, $GNAME, $NAME); mod $NAME { + use std::any::Any; + use std::sync::Arc; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, DataFusionError, Result}; + use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::*; - use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - }; - use std::any::Any; - use std::sync::Arc; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] pub struct $UDF { @@ -318,8 +319,11 @@ macro_rules! make_math_binary_udf { } } - fn monotonicity(&self) -> Result> { - Ok($MONOTONICITY) + fn monotonicity( + &self, + input: &[ExprProperties], + ) -> Result { + $MONOTONICITY(input) } fn invoke(&self, args: &[ColumnarValue]) -> Result { diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index e05dc8665285..a752102913ba 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -17,23 +17,20 @@ //! math expressions -use arrow::array::Decimal128Array; -use arrow::array::Decimal256Array; -use arrow::array::Int16Array; -use arrow::array::Int32Array; -use arrow::array::Int64Array; -use arrow::array::Int8Array; -use arrow::datatypes::DataType; -use datafusion_common::{exec_err, not_impl_err}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::ColumnarValue; - -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::error::ArrowError; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; +use arrow::array::{ + ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, +}; +use arrow::datatypes::DataType; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + type MathArrayFunction = fn(&Vec) -> Result; macro_rules! make_abs_function { @@ -170,7 +167,21 @@ impl ScalarUDFImpl for AbsFunc { let input_data_type = args[0].data_type(); let abs_fun = create_abs_function(input_data_type)?; - let arr = abs_fun(&args)?; - Ok(ColumnarValue::Array(arr)) + abs_fun(&args).map(ColumnarValue::Array) + } + + fn monotonicity(&self, input: &[ExprProperties]) -> Result { + // Non-decreasing for x ≥ 0 and symmetrically non-increasing for x ≤ 0. + let arg = &input[0]; + let range = &arg.range; + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else if range.lt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(-arg.sort_properties) + } else { + Ok(SortProperties::Unordered) + } } } diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index e6c698ad1a80..8c1e8ac8fea3 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -17,6 +17,12 @@ //! Math function: `log()`. 
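
The ABS change above and the `monotonicity` module introduced later in this patch share one pattern: compare the argument's known range against zero, then pass the ordering through, flip it, or give up. A sketch of that pattern for a hypothetical even function such as x² (not part of this patch), using only interval APIs that appear elsewhere in the diff:

```rust
use datafusion_common::Result;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};

// Hypothetical helper for an even function such as x^2: order is preserved
// on a non-negative range, reversed on a non-positive range, and unknown
// when the range straddles zero.
fn square_monotonicity(input: &[ExprProperties]) -> Result<SortProperties> {
    let arg = &input[0];
    let zero = Interval::make_zero(&arg.range.lower().data_type())?;

    if arg.range.gt_eq(&zero)? == Interval::CERTAINLY_TRUE {
        Ok(arg.sort_properties)
    } else if arg.range.lt_eq(&zero)? == Interval::CERTAINLY_TRUE {
        Ok(-arg.sort_properties)
    } else {
        Ok(SortProperties::Unordered)
    }
}
```
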
+use std::any::Any; +use std::sync::Arc; + +use super::power::PowerFunc; + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; use arrow::datatypes::DataType; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, @@ -24,15 +30,9 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{lit, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF}; - -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use datafusion_expr::TypeSignature::*; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{lit, ColumnarValue, Expr, ScalarUDF, TypeSignature::*}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -use super::power::PowerFunc; #[derive(Debug)] pub struct LogFunc { @@ -81,8 +81,23 @@ impl ScalarUDFImpl for LogFunc { } } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true), Some(false)])) + fn monotonicity(&self, input: &[ExprProperties]) -> Result { + match (input[0].sort_properties, input[1].sort_properties) { + (first @ SortProperties::Ordered(value), SortProperties::Ordered(base)) + if !value.descending && base.descending + || value.descending && !base.descending => + { + Ok(first) + } + ( + first @ (SortProperties::Ordered(_) | SortProperties::Singleton), + SortProperties::Singleton, + ) => Ok(first), + (SortProperties::Singleton, second @ SortProperties::Ordered(_)) => { + Ok(-second) + } + _ => Ok(SortProperties::Unordered), + } } // Support overloaded log(base, x) and log(x) which defaults to log(10, x) @@ -213,14 +228,13 @@ fn is_pow(func: &ScalarUDF) -> bool { mod tests { use std::collections::HashMap; - use datafusion_common::{ - cast::{as_float32_array, as_float64_array}, - DFSchema, - }; - use datafusion_expr::{execution_props::ExecutionProps, simplify::SimplifyContext}; - use super::*; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_common::DFSchema; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::simplify::SimplifyContext; + #[test] fn test_log_f64() { let args = [ diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index b6e8d26b6460..6c26ce79d0a5 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -17,9 +17,12 @@ //! 
"math" DataFusion functions -use datafusion_expr::ScalarUDF; use std::sync::Arc; +use crate::math::monotonicity::*; + +use datafusion_expr::ScalarUDF; + pub mod abs; pub mod cot; pub mod factorial; @@ -27,6 +30,7 @@ pub mod gcd; pub mod iszero; pub mod lcm; pub mod log; +pub mod monotonicity; pub mod nans; pub mod nanvl; pub mod pi; @@ -37,42 +41,60 @@ pub mod trunc; // Create UDFs make_udf_function!(abs::AbsFunc, ABS, abs); -make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None); -make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); -make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None); -make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)])); -make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)])); -make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); -make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, Some(vec![Some(true)])); -make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, None); -make_math_unary_udf!(CeilFunc, CEIL, ceil, ceil, Some(vec![Some(true)])); -make_math_unary_udf!(CosFunc, COS, cos, cos, None); -make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, None); +make_math_unary_udf!(AcosFunc, ACOS, acos, acos, super::acos_monotonicity); +make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, super::acosh_monotonicity); +make_math_unary_udf!(AsinFunc, ASIN, asin, asin, super::asin_monotonicity); +make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, super::asinh_monotonicity); +make_math_unary_udf!(AtanFunc, ATAN, atan, atan, super::atan_monotonicity); +make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, super::atanh_monotonicity); +make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, super::atan2_monotonicity); +make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, super::cbrt_monotonicity); +make_math_unary_udf!(CeilFunc, CEIL, ceil, ceil, super::ceil_monotonicity); +make_math_unary_udf!(CosFunc, COS, cos, cos, super::cos_monotonicity); +make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, super::cosh_monotonicity); make_udf_function!(cot::CotFunc, COT, cot); -make_math_unary_udf!(DegreesFunc, DEGREES, degrees, to_degrees, None); -make_math_unary_udf!(ExpFunc, EXP, exp, exp, Some(vec![Some(true)])); +make_math_unary_udf!( + DegreesFunc, + DEGREES, + degrees, + to_degrees, + super::degrees_monotonicity +); +make_math_unary_udf!(ExpFunc, EXP, exp, exp, super::exp_monotonicity); make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial); -make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, Some(vec![Some(true)])); +make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, super::floor_monotonicity); make_udf_function!(log::LogFunc, LOG, log); make_udf_function!(gcd::GcdFunc, GCD, gcd); make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(iszero::IsZeroFunc, ISZERO, iszero); make_udf_function!(lcm::LcmFunc, LCM, lcm); -make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)])); -make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); -make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); +make_math_unary_udf!(LnFunc, LN, ln, ln, super::ln_monotonicity); +make_math_unary_udf!(Log2Func, LOG2, log2, log2, super::log2_monotonicity); +make_math_unary_udf!(Log10Func, LOG10, log10, log10, super::log10_monotonicity); make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl); make_udf_function!(pi::PiFunc, PI, pi); make_udf_function!(power::PowerFunc, POWER, power); -make_math_unary_udf!(RadiansFunc, RADIANS, radians, to_radians, None); +make_math_unary_udf!( + 
RadiansFunc, + RADIANS, + radians, + to_radians, + super::radians_monotonicity +); make_udf_function!(random::RandomFunc, RANDOM, random); make_udf_function!(round::RoundFunc, ROUND, round); -make_math_unary_udf!(SignumFunc, SIGNUM, signum, signum, None); -make_math_unary_udf!(SinFunc, SIN, sin, sin, None); -make_math_unary_udf!(SinhFunc, SINH, sinh, sinh, None); -make_math_unary_udf!(SqrtFunc, SQRT, sqrt, sqrt, None); -make_math_unary_udf!(TanFunc, TAN, tan, tan, None); -make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, None); +make_math_unary_udf!( + SignumFunc, + SIGNUM, + signum, + signum, + super::signum_monotonicity +); +make_math_unary_udf!(SinFunc, SIN, sin, sin, super::sin_monotonicity); +make_math_unary_udf!(SinhFunc, SINH, sinh, sinh, super::sinh_monotonicity); +make_math_unary_udf!(SqrtFunc, SQRT, sqrt, sqrt, super::sqrt_monotonicity); +make_math_unary_udf!(TanFunc, TAN, tan, tan, super::tan_monotonicity); +make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, super::tanh_monotonicity); make_udf_function!(trunc::TruncFunc, TRUNC, trunc); pub mod expr_fn { diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs new file mode 100644 index 000000000000..5ce5654ae79e --- /dev/null +++ b/datafusion/functions/src/math/monotonicity.rs @@ -0,0 +1,241 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; + +fn symmetric_unit_interval(data_type: &DataType) -> Result { + Interval::try_new( + ScalarValue::new_negative_one(data_type)?, + ScalarValue::new_one(data_type)?, + ) +} + +/// Non-increasing on the interval \[−1, 1\], undefined otherwise. +pub fn acos_monotonicity(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let valid_domain = symmetric_unit_interval(&range.lower().data_type())?; + + if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { + Ok(-arg.sort_properties) + } else { + exec_err!("Input range of ACOS contains out-of-domain values") + } +} + +/// Non-decreasing for x ≥ 1, undefined otherwise. +pub fn acosh_monotonicity(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let valid_domain = Interval::try_new( + ScalarValue::new_one(&range.lower().data_type())?, + ScalarValue::try_from(&range.upper().data_type())?, + )?; + + if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of ACOSH contains out-of-domain values") + } +} + +/// Non-decreasing on the interval \[−1, 1\], undefined otherwise. 
+pub fn asin_monotonicity(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let valid_domain = symmetric_unit_interval(&range.lower().data_type())?; + + if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of ASIN contains out-of-domain values") + } +} + +/// Non-decreasing for all real numbers. +pub fn asinh_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-decreasing for all real numbers. +pub fn atan_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-decreasing on the interval \[−1, 1\], undefined otherwise. +pub fn atanh_monotonicity(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let valid_domain = symmetric_unit_interval(&range.lower().data_type())?; + + if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of ATANH contains out-of-domain values") + } +} + +/// Monotonicity depends on the quadrant. +// TODO: Implement monotonicity of the ATAN2 function. +pub fn atan2_monotonicity(_input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) +} + +/// Non-decreasing for all real numbers. +pub fn cbrt_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-decreasing for all real numbers. +pub fn ceil_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-increasing on \[0, π\] and then non-decreasing on \[π, 2π\]. +/// This pattern repeats periodically with a period of 2π. +// TODO: Implement monotonicity of the COS function. +pub fn cos_monotonicity(_input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) +} + +/// Non-decreasing for x ≥ 0 and symmetrically non-increasing for x ≤ 0. +pub fn cosh_monotonicity(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else if range.lt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(-arg.sort_properties) + } else { + Ok(SortProperties::Unordered) + } +} + +/// Non-decreasing function that converts radians to degrees. +pub fn degrees_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-decreasing for all real numbers. +pub fn exp_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-decreasing for all real numbers. +pub fn floor_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-decreasing for x ≥ 0, undefined otherwise. +pub fn ln_monotonicity(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of LN contains out-of-domain values") + } +} + +/// Non-decreasing for x ≥ 0, undefined otherwise. +pub fn log2_monotonicity(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)?
== Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of LOG2 contains out-of-domain values") + } +} + +/// Non-decreasing for x ≥ 0, undefined otherwise. +pub fn log10_monotonicity(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of LOG10 contains out-of-domain values") + } +} + +/// Non-decreasing for all real numbers x. +pub fn radians_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-decreasing for all real numbers x. +pub fn signum_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-decreasing on \[−π/2, π/2\] and then non-increasing on \[π/2, 3π/2\]. +/// This pattern repeats periodically with a period of 2π. +// TODO: Implement monotonicity of the SIN function. +pub fn sin_monotonicity(_input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) +} + +/// Non-decreasing for all real numbers. +pub fn sinh_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +/// Non-decreasing for x ≥ 0, undefined otherwise. +pub fn sqrt_monotonicity(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of SQRT contains out-of-domain values") + } +} + +/// Non-decreasing between vertical asymptotes at x = k * π ± π / 2 for any +/// integer k. +// TODO: Implement monotonicity of the TAN function. +pub fn tan_monotonicity(_input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) +} + +/// Non-decreasing for all real numbers. +pub fn tanh_monotonicity(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index f9403e411fe2..60c94b6ca622 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -19,10 +19,9 @@ use std::any::Any; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; - use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, FuncMonotonicity, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] pub struct PiFunc { @@ -70,7 +69,8 @@ impl ScalarUDFImpl for PiFunc { )))) } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn monotonicity(&self, _input: &[ExprProperties]) -> Result { + // This function returns a constant value.
+ Ok(SortProperties::Singleton) } } diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index f4a163137a35..600f4fd5472a 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -18,15 +18,15 @@ use std::any::Any; use std::sync::Arc; +use crate::utils::make_scalar_function; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Float32, Float64}; - -use crate::utils::make_scalar_function; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, FuncMonotonicity}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] pub struct RoundFunc { @@ -80,8 +80,19 @@ impl ScalarUDFImpl for RoundFunc { make_scalar_function(round, vec![])(args) } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn monotonicity(&self, input: &[ExprProperties]) -> Result { + // round preserves the order of the first argument + let value = &input[0]; + let precision = input.get(1); + + if precision + .map(|r| r.sort_properties.eq(&SortProperties::Singleton)) + .unwrap_or(true) + { + Ok(value.sort_properties) + } else { + Ok(SortProperties::Unordered) + } } } @@ -179,10 +190,12 @@ pub fn round(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::round::round; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; #[test] fn test_round_f32() { diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 6f88099889cc..0c4d38564b9f 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -18,16 +18,16 @@ use std::any::Any; use std::sync::Arc; +use crate::utils::make_scalar_function; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Float32, Float64}; - -use crate::utils::make_scalar_function; use datafusion_common::ScalarValue::Int64; use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, FuncMonotonicity}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] pub struct TruncFunc { @@ -86,8 +86,19 @@ impl ScalarUDFImpl for TruncFunc { make_scalar_function(trunc, vec![])(args) } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn monotonicity(&self, input: &[ExprProperties]) -> Result { + // trunc preserves the order of the first argument + let value = &input[0]; + let precision = input.get(1); + + if precision + .map(|r| r.sort_properties.eq(&SortProperties::Singleton)) + .unwrap_or(true) + { + Ok(value.sort_properties) + } else { + Ok(SortProperties::Unordered) + } } } @@ -156,10 +167,12 @@ fn compute_truncate64(x: f64, y: i64) -> f64 { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::trunc::trunc; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; 
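Both `round` and `trunc` accept an optional second argument giving the number of decimal places. Their new `monotonicity` implementations above only propagate the first argument's order when that second argument is absent or a known constant (`SortProperties::Singleton`); a per-row precision can reorder results, so the output is reported as unordered. A small self-contained model of that decision; the `Props` enum is a simplified stand-in, not the real `SortProperties`:

```rust
// Simplified model of the check `round`/`trunc` now perform.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Props {
    Ordered,
    Unordered,
    Singleton, // a constant value
}

fn round_like_monotonicity(value: Props, precision: Option<Props>) -> Props {
    // The value column's order survives only if the precision argument is
    // either absent or a constant.
    if precision.map(|p| p == Props::Singleton).unwrap_or(true) {
        value
    } else {
        Props::Unordered
    }
}

fn main() {
    // round(x)    -> keeps x's order
    assert_eq!(round_like_monotonicity(Props::Ordered, None), Props::Ordered);
    // round(x, 2) -> literal precision, still keeps x's order
    assert_eq!(
        round_like_monotonicity(Props::Ordered, Some(Props::Singleton)),
        Props::Ordered
    );
    // round(x, y) -> precision varies per row, order is lost
    assert_eq!(
        round_like_monotonicity(Props::Ordered, Some(Props::Ordered)),
        Props::Unordered
    );
}
```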
use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; #[test] fn test_truncate_32() { diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 5c12d4559e74..884db24d9ec8 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -28,12 +28,44 @@ make_udf_function!( REGEXP_REPLACE, regexp_replace ); -export_functions!(( - regexp_match, - input_arg1 input_arg2, - "returns a list of regular expression matches in a string. " -),( - regexp_like, - input_arg1 input_arg2, - "Returns true if a has at least one match in a string,false otherwise." -),(regexp_replace, arg1 arg2 arg3 arg4, "Replaces substrings in a string that match")); + +pub mod expr_fn { + use datafusion_expr::Expr; + + /// Returns a list of regular expression matches in a string. + pub fn regexp_match(values: Expr, regex: Expr, flags: Option) -> Expr { + let mut args = vec![values, regex]; + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_match().call(args) + } + + /// Returns true if a has at least one match in a string, false otherwise. + pub fn regexp_like(values: Expr, regex: Expr, flags: Option) -> Expr { + let mut args = vec![values, regex]; + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_like().call(args) + } + + /// Replaces substrings in a string that match. + pub fn regexp_replace( + string: Expr, + pattern: Expr, + replacement: Expr, + flags: Option, + ) -> Expr { + let mut args = vec![string, pattern, replacement]; + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_replace().call(args) + } +} + +#[doc = r" Return a list of all functions in this package"] +pub fn functions() -> Vec> { + vec![regexp_match(), regexp_like(), regexp_replace()] +} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index a607d49ef967..dfbd5f5632ee 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -25,7 +25,9 @@ use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionDefinition, WindowFunction, }; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; +use datafusion_expr::{ + aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition, +}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -54,23 +56,37 @@ fn is_wildcard(expr: &Expr) -> bool { } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { - matches!( - &aggregate_function.func_def, - AggregateFunctionDefinition::BuiltIn( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ) - ) && aggregate_function.args.len() == 1 - && is_wildcard(&aggregate_function.args[0]) + match aggregate_function { + AggregateFunction { + func_def: AggregateFunctionDefinition::UDF(udf), + args, + .. + } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true, + AggregateFunction { + func_def: + AggregateFunctionDefinition::BuiltIn( + datafusion_expr::aggregate_function::AggregateFunction::Count, + ), + args, + .. 
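The regex module above drops the `export_functions!` invocation in favor of a hand-written `expr_fn` module so the optional regex flags can be passed as an `Option`. A hedged usage sketch; the `datafusion_functions::regex::expr_fn` path and the `col`/`lit` helpers from `datafusion_expr` are assumptions about the surrounding crates, not code taken from this patch:

```rust
use datafusion_expr::{col, lit};
// Module path assumed from the hunk above; adjust to the actual re-exports in
// your build.
use datafusion_functions::regex::expr_fn::{regexp_like, regexp_replace};

fn main() {
    // regexp_like(name, '^foo') -- no flags, so pass None
    let is_foo = regexp_like(col("name"), lit("^foo"), None);

    // regexp_replace(name, 'foo', 'bar', 'gi') -- the flags expression is
    // appended as an extra argument only when Some(..) is given
    let renamed = regexp_replace(col("name"), lit("foo"), lit("bar"), Some(lit("gi")));

    println!("{is_foo}\n{renamed}");
}
```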
+ } if args.len() == 1 && is_wildcard(&args[0]) => true, + _ => false, + } } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { - matches!( - &window_function.fun, + let args = &window_function.args; + match window_function.fun { WindowFunctionDefinition::AggregateFunction( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ) - ) && window_function.args.len() == 1 - && is_wildcard(&window_function.args[0]) + aggregate_function::AggregateFunction::Count, + ) if args.len() == 1 && is_wildcard(&args[0]) => true, + WindowFunctionDefinition::AggregateUDF(ref udaf) + if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => + { + true + } + _ => false, + } } fn analyze_internal(plan: LogicalPlan) -> Result> { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 60b81aff9aaa..0f1f3ba7e729 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, @@ -31,8 +31,8 @@ use datafusion_expr::expr::{ self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, WindowFunction, }; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ comparison_coercion, get_input_types, like_coercion, @@ -52,6 +52,7 @@ use datafusion_expr::{ }; use crate::analyzer::AnalyzerRule; +use crate::utils::NamePreserver; #[derive(Default)] pub struct TypeCoercion {} @@ -68,26 +69,28 @@ impl AnalyzerRule for TypeCoercion { } fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - analyze_internal(&DFSchema::empty(), &plan) + let empty_schema = DFSchema::empty(); + + let transformed_plan = plan + .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))? + .data; + + Ok(transformed_plan) } } +/// use the external schema to handle the correlated subqueries case +/// +/// Assumes that children have already been optimized fn analyze_internal( - // use the external schema to handle the correlated subqueries case external_schema: &DFSchema, - plan: &LogicalPlan, -) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| analyze_internal(external_schema, p)) - .collect::>>()?; + plan: LogicalPlan, +) -> Result> { // get schema representing all available input fields. 
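The reworked type-coercion pass (this and the following hunks) walks the plan bottom-up with `transform_up_with_subqueries` and funnels all binary-operator coercion through a single `coerce_binary_op` helper, which is then reused for join equality keys. The underlying computation picks one common type for both sides and casts each side to it. A minimal sketch of that computation using `comparison_coercion`, which this file already imports; the concrete results shown are assumptions about its behavior rather than output taken from the patch:

```rust
use arrow::datatypes::DataType;
use datafusion_expr::type_coercion::binary::comparison_coercion;

fn main() {
    // t1.a = t2.b with Int32 on one side and Int64 on the other: both sides
    // are cast to the common comparison type before the keys are compared.
    let common = comparison_coercion(&DataType::Int32, &DataType::Int64);
    assert_eq!(common, Some(DataType::Int64));

    // When no common type exists the helper returns None, and the coercion
    // pass turns that into a planning error.
    println!("{:?}", comparison_coercion(&DataType::Boolean, &DataType::Date32));
}
```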
This is used for data type // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); + let mut schema = merge_schema(plan.inputs()); - if let LogicalPlan::TableScan(ts) = plan { + if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -100,25 +103,75 @@ fn analyze_internal( // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) schema.merge(external_schema); - let mut expr_rewrite = TypeCoercionRewriter { schema: &schema }; - - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure aggregate names don't change: - // https://github.com/apache/datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; - - plan.with_new_exprs(new_expr, new_inputs) + let mut expr_rewrite = TypeCoercionRewriter::new(&schema); + + let name_preserver = NamePreserver::new(&plan); + // apply coercion rewrite all expressions in the plan individually + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + expr.rewrite(&mut expr_rewrite)? + .map_data(|expr| original_name.restore(expr)) + })? + // coerce join expressions specially + .map_data(|plan| expr_rewrite.coerce_joins(plan))? + // recompute the schema after the expressions have been rewritten as the types may have changed + .map_data(|plan| plan.recompute_schema()) } pub(crate) struct TypeCoercionRewriter<'a> { pub(crate) schema: &'a DFSchema, } +impl<'a> TypeCoercionRewriter<'a> { + fn new(schema: &'a DFSchema) -> Self { + Self { schema } + } + + /// Coerce join equality expressions + /// + /// Joins must be treated specially as their equality expressions are stored + /// as a parallel list of left and right expressions, rather than a single + /// equality expression + /// + /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored + /// as a list of `(t1.a, t2.b), (t1.x, t2.y)` + fn coerce_joins(&mut self, plan: LogicalPlan) -> Result { + let LogicalPlan::Join(mut join) = plan else { + return Ok(plan); + }; + + join.on = join + .on + .into_iter() + .map(|(lhs, rhs)| { + // coerce the arguments as though they were a single binary equality + // expression + let (lhs, rhs) = self.coerce_binary_op(lhs, Operator::Eq, rhs)?; + Ok((lhs, rhs)) + }) + .collect::>>()?; + + Ok(LogicalPlan::Join(join)) + } + + fn coerce_binary_op( + &self, + left: Expr, + op: Operator, + right: Expr, + ) -> Result<(Expr, Expr)> { + let (left_type, right_type) = get_input_types( + &left.get_type(self.schema)?, + &op, + &right.get_type(self.schema)?, + )?; + Ok(( + left.cast_to(&left_type, self.schema)?, + right.cast_to(&right_type, self.schema)?, + )) + } +} + impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { type Node = Expr; @@ -131,14 +184,15 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { subquery, outer_ref_columns, }) => { - let new_plan = analyze_internal(self.schema, &subquery)?; + let new_plan = analyze_internal(self.schema, unwrap_arc(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, }))) } Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(self.schema, &subquery.subquery)?; + let new_plan = + analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data; Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), @@ -152,7 +206,8 @@ impl<'a> 
TreeNodeRewriter for TypeCoercionRewriter<'a> { subquery, negated, }) => { - let new_plan = analyze_internal(self.schema, &subquery.subquery)?; + let new_plan = + analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data; let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( @@ -221,15 +276,11 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let (left_type, right_type) = get_input_types( - &left.get_type(self.schema)?, - &op, - &right.get_type(self.schema)?, - )?; + let (left, right) = self.coerce_binary_op(*left, op, *right)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, self.schema)?), + Box::new(left), op, - Box::new(right.cast_to(&right_type, self.schema)?), + Box::new(right), )))) } Expr::Between(Between { diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 3959223e68c1..b55b1a7f8f2d 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -38,25 +38,71 @@ use datafusion_physical_expr::execution_props::ExecutionProps; /// 'Filter'. It adds the inner reference columns to the 'Projection' or /// 'Aggregate' of the subquery if they are missing, so that they can be /// evaluated by the parent operator as the join condition. +#[derive(Debug)] pub struct PullUpCorrelatedExpr { pub join_filters: Vec, - // mapping from the plan to its holding correlated columns + /// mapping from the plan to its holding correlated columns pub correlated_subquery_cols_map: HashMap>, pub in_predicate_opt: Option, - // indicate whether it is Exists(Not Exists) SubQuery + /// Is this an Exists(Not Exists) SubQuery. Defaults to **FALSE** pub exists_sub_query: bool, - // indicate whether the correlated expressions can pull up or not + /// Can the correlated expressions be pulled up. 
Defaults to **TRUE** pub can_pull_up: bool, - // indicate whether need to handle the Count bug during the pull up process + /// Do we need to handle [the Count bug] during the pull up process + /// + /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 pub need_handle_count_bug: bool, - // mapping from the plan to its expressions' evaluation result on empty batch + /// mapping from the plan to its expressions' evaluation result on empty batch pub collected_count_expr_map: HashMap, - // pull up having expr, which must be evaluated after the Join + /// pull up having expr, which must be evaluated after the Join pub pull_up_having_expr: Option, } +impl Default for PullUpCorrelatedExpr { + fn default() -> Self { + Self::new() + } +} + +impl PullUpCorrelatedExpr { + pub fn new() -> Self { + Self { + join_filters: vec![], + correlated_subquery_cols_map: HashMap::new(), + in_predicate_opt: None, + exists_sub_query: false, + can_pull_up: true, + need_handle_count_bug: false, + collected_count_expr_map: HashMap::new(), + pull_up_having_expr: None, + } + } + + /// Set if we need to handle [the Count bug] during the pull up process + /// + /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 + pub fn with_need_handle_count_bug(mut self, need_handle_count_bug: bool) -> Self { + self.need_handle_count_bug = need_handle_count_bug; + self + } + + /// Set the in_predicate_opt + pub fn with_in_predicate_opt(mut self, in_predicate_opt: Option) -> Self { + self.in_predicate_opt = in_predicate_opt; + self + } + + /// Set if this is an Exists(Not Exists) SubQuery + pub fn with_exists_sub_query(mut self, exists_sub_query: bool) -> Self { + self.exists_sub_query = exists_sub_query; + self + } +} + /// Used to indicate the unmatched rows from the inner(subquery) table after the left out Join -/// This is used to handle the Count bug +/// This is used to handle [the Count bug] +/// +/// [the Count bug]: https://github.com/apache/datafusion/pull/10500 pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; /// Mapping from expr display name to its evaluation result on empty record diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 58fd8557194f..88ce300e5c9a 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -248,16 +248,10 @@ fn build_join( let subquery = query_info.query.subquery.as_ref(); let subquery_alias = alias.next("__correlated_sq"); - let mut pull_up = PullUpCorrelatedExpr { - join_filters: vec![], - correlated_subquery_cols_map: Default::default(), - in_predicate_opt: in_predicate_opt.clone(), - exists_sub_query: in_predicate_opt.is_none(), - can_pull_up: true, - need_handle_count_bug: false, - collected_count_expr_map: Default::default(), - pull_up_having_expr: None, - }; + let mut pull_up = PullUpCorrelatedExpr::new() + .with_in_predicate_opt(in_predicate_opt.clone()) + .with_exists_sub_query(in_predicate_opt.is_none()); + let new_plan = subquery.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 57b38bd0d0fd..b684b5490342 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -14,6 +14,7 @@ //! 
[`PushDownFilter`] applies filters as early as possible +use indexmap::IndexSet; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -23,10 +24,9 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - internal_err, plan_datafusion_err, qualified_name, Column, DFSchema, DFSchemaRef, + internal_err, plan_err, qualified_name, Column, DFSchema, DFSchemaRef, JoinConstraint, Result, }; -use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ @@ -131,7 +131,8 @@ use crate::{OptimizerConfig, OptimizerRule}; #[derive(Default)] pub struct PushDownFilter {} -/// For a given JOIN logical plan, determine whether each side of the join is preserved. +/// For a given JOIN type, determine whether each side of the join is preserved. +/// /// We say a join side is preserved if the join returns all or a subset of the rows from /// the relevant side, such that each row of the output table directly maps to a row of /// the preserved input table. If a table is not preserved, it can provide extra null rows. @@ -150,44 +151,33 @@ pub struct PushDownFilter {} /// non-preserved side it can be more tricky. /// /// Returns a tuple of booleans - (left_preserved, right_preserved). -fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((true, false)), - JoinType::Right => Ok((false, true)), - JoinType::Full => Ok((false, false)), - // No columns from the right side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), - // No columns from the left side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)), - }, - LogicalPlan::CrossJoin(_) => Ok((true, true)), - _ => internal_err!("lr_is_preserved only valid for JOIN nodes"), +fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { + match join_type { + JoinType::Inner => Ok((true, true)), + JoinType::Left => Ok((true, false)), + JoinType::Right => Ok((false, true)), + JoinType::Full => Ok((false, false)), + // No columns from the right side of the join can be referenced in output + // predicates for semi/anti joins, so whether we specify t/f doesn't matter. + JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), + // No columns from the left side of the join can be referenced in output + // predicates for semi/anti joins, so whether we specify t/f doesn't matter. + JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)), } } /// For a given JOIN logical plan, determine whether each side of the join is preserved /// in terms on join filtering. -/// /// Predicates from join filter can only be pushed to preserved join side. -fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. 
}) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((false, true)), - JoinType::Right => Ok((true, false)), - JoinType::Full => Ok((false, false)), - JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)), - JoinType::LeftAnti => Ok((false, true)), - JoinType::RightAnti => Ok((true, false)), - }, - LogicalPlan::CrossJoin(_) => { - internal_err!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes") - } - _ => internal_err!("on_lr_is_preserved only valid for JOIN nodes"), +fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { + match join_type { + JoinType::Inner => Ok((true, true)), + JoinType::Left => Ok((false, true)), + JoinType::Right => Ok((true, false)), + JoinType::Full => Ok((false, false)), + JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)), + JoinType::LeftAnti => Ok((false, true)), + JoinType::RightAnti => Ok((true, false)), } } @@ -400,23 +390,20 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option, - infer_predicates: Vec, - join_plan: &LogicalPlan, - left: &LogicalPlan, - right: &LogicalPlan, + inferred_join_predicates: Vec, + mut join: Join, on_filter: Vec, - is_inner_join: bool, ) -> Result> { - let on_filter_empty = on_filter.is_empty(); + let is_inner_join = join.join_type == JoinType::Inner; // Get pushable predicates from current optimizer state - let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?; + let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?; // The predicates can be divided to three categories: // 1) can push through join to its children(left or right) // 2) can be converted to join conditions if the join type is Inner // 3) should be kept as filter conditions - let left_schema = left.schema(); - let right_schema = right.schema(); + let left_schema = join.left.schema(); + let right_schema = join.right.schema(); let mut left_push = vec![]; let mut right_push = vec![]; let mut keep_predicates = vec![]; @@ -438,7 +425,7 @@ fn push_down_all_join( } // For infer predicates, if they can not push through join, just drop them - for predicate in infer_predicates { + for predicate in inferred_join_predicates { if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { left_push.push(predicate); } else if right_preserved @@ -449,7 +436,7 @@ fn push_down_all_join( } if !on_filter.is_empty() { - let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join_plan)?; + let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?; for on in on_filter { if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? { left_push.push(on) @@ -474,46 +461,29 @@ fn push_down_all_join( right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema)); } - let left = match conjunction(left_push) { - Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?) - } - None => left.clone(), - }; - let right = match conjunction(right_push) { - Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?) - } - None => right.clone(), - }; - // Create a new Join with the new `left` and `right` - // - // expressions() output for Join is a vector consisting of - // 1. join keys - columns mentioned in ON clause - // 2. optional predicate - in case join filter is not empty, - // it always will be the last element, otherwise result - // vector will contain only join keys (without additional - // element representing filter). 
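The preservation tables above decide where `WHERE`-clause predicates are allowed to travel. The rule is easiest to see on concrete rows: for a `LEFT JOIN`, a predicate on the null-producing (right) side cannot simply be pushed below the join, because unmatched left-side rows must still appear with `NULL`s. A self-contained illustration using plain vectors, not DataFusion types:

```rust
// Self-contained illustration of why the null-producing side of a LEFT JOIN is
// not "preserved" for WHERE-clause pushdown.
fn left_join(left: &[i32], right: &[i32]) -> Vec<(i32, Option<i32>)> {
    left.iter()
        .map(|&l| (l, right.iter().copied().find(|&r| r == l)))
        .collect()
}

fn main() {
    let t1 = [1, 2, 3];
    let t2 = [2, 3, 4];

    // WHERE t2.v = 3 evaluated *after* the join: only the matching row remains.
    let filtered_after: Vec<_> = left_join(&t1, &t2)
        .into_iter()
        .filter(|(_, r)| *r == Some(3))
        .collect();
    assert_eq!(filtered_after, vec![(3, Some(3))]);

    // Pushing the same predicate *below* the join instead: unmatched t1 rows
    // come back as NULLs, so the two plans are not equivalent and the rule must
    // keep the filter above the join.
    let filtered_below = left_join(&t1, &[3]);
    assert_eq!(filtered_below, vec![(1, None), (2, None), (3, Some(3))]);
}
```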
- let mut exprs = join_plan.expressions(); - if !on_filter_empty { - exprs.pop(); - } - exprs.extend(join_conditions.into_iter().reduce(Expr::and)); - let plan = join_plan.with_new_exprs(exprs, vec![left, right])?; - - // wrap the join on the filter whose predicates must be kept - match conjunction(keep_predicates) { - Some(predicate) => { - let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?; - Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan))) - } - None => Ok(Transformed::no(plan)), + if let Some(predicate) = conjunction(left_push) { + join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); } + if let Some(predicate) = conjunction(right_push) { + join.right = + Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?)); + } + + // Add any new join conditions as the non join predicates + join.filter = conjunction(join_conditions); + + // wrap the join on the filter whose predicates must be kept, if any + let plan = LogicalPlan::Join(join); + let plan = if let Some(predicate) = conjunction(keep_predicates) { + LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?) + } else { + plan + }; + Ok(Transformed::yes(plan)) } fn push_down_join( - plan: &LogicalPlan, - join: &Join, + join: Join, parent_predicate: Option<&Expr>, ) -> Result> { // Split the parent predicate into individual conjunctive parts. @@ -526,93 +496,102 @@ fn push_down_join( .as_ref() .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone())); - let mut is_inner_join = false; - let infer_predicates = if join.join_type == JoinType::Inner { - is_inner_join = true; - - // Only allow both side key is column. - let join_col_keys = join - .on - .iter() - .filter_map(|(l, r)| { - let left_col = l.try_as_col().cloned()?; - let right_col = r.try_as_col().cloned()?; - Some((left_col, right_col)) - }) - .collect::>(); - - // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides. Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. - // This logic should also been applied to conditions in JOIN ON clause - predicates - .iter() - .chain(on_filters.iter()) - .filter_map(|predicate| { - let mut join_cols_to_replace = HashMap::new(); - - let columns = match predicate.to_columns() { - Ok(columns) => columns, - Err(e) => return Some(Err(e)), - }; + // Are there any new join predicates that can be inferred from the filter expressions? 
+ let inferred_join_predicates = + infer_join_predicates(&join, &predicates, &on_filters)?; - for col in columns.iter() { - for (l, r) in join_col_keys.iter() { - if col == l { - join_cols_to_replace.insert(col, r); - break; - } else if col == r { - join_cols_to_replace.insert(col, l); - break; - } - } - } + if on_filters.is_empty() + && predicates.is_empty() + && inferred_join_predicates.is_empty() + { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } - if join_cols_to_replace.is_empty() { - return None; - } + push_down_all_join(predicates, inferred_join_predicates, join, on_filters) +} - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; +/// Extracts any equi-join join predicates from the given filter expressions. +/// +/// Parameters +/// * `join` the join in question +/// +/// * `predicates` the pushed down filter expression +/// +/// * `on_filters` filters from the join ON clause that have not already been +/// identified as join predicates +/// +fn infer_join_predicates( + join: &Join, + predicates: &[Expr], + on_filters: &[Expr], +) -> Result> { + if join.join_type != JoinType::Inner { + return Ok(vec![]); + } - Some(Ok(join_side_predicate)) - }) - .collect::>>()? - } else { - vec![] - }; + // Only allow both side key is column. + let join_col_keys = join + .on + .iter() + .filter_map(|(l, r)| { + let left_col = l.try_as_col()?; + let right_col = r.try_as_col()?; + Some((left_col, right_col)) + }) + .collect::>(); - if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() { - return Ok(Transformed::no(plan.clone())); - } + // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down + // For inner joins, duplicate filters for joined columns so filters can be pushed down + // to both sides. Take the following query as an example: + // + // ```sql + // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 + // ``` + // + // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while + // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. + // + // Join clauses with `Using` constraints also take advantage of this logic to make sure + // predicates reference the shared join columns are pushed to both sides. 
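`infer_join_predicates` realizes the `SELECT * FROM t1 JOIN t2 ON t1.id = t2.uid WHERE t1.id > 1` example from the comments by substituting join-key columns into the pushed-down predicates with `replace_col`. A hedged sketch of that substitution in isolation; the `Column::new` construction and the displayed output are assumptions based on how these helpers are used in this file, not code from the patch:

```rust
use std::collections::HashMap;

use datafusion_common::{Column, Result};
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::{col, lit};

fn main() -> Result<()> {
    // Join key: t1.id = t2.uid; pushed-down predicate: t1.id > 1
    let t1_id = Column::new(Some("t1"), "id");
    let t2_uid = Column::new(Some("t2"), "uid");
    let predicate = col("t1.id").gt(lit(1));

    // Rewriting t1.id -> t2.uid yields the extra predicate t2.uid > 1, which
    // can then be pushed to the other side of the inner join.
    let mut join_cols_to_replace = HashMap::new();
    join_cols_to_replace.insert(&t1_id, &t2_uid);
    let inferred = replace_col(predicate, &join_cols_to_replace)?;

    println!("{inferred}"); // expected to display roughly `t2.uid > Int32(1)`
    Ok(())
}
```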
+ // This logic should also been applied to conditions in JOIN ON clause + predicates + .iter() + .chain(on_filters.iter()) + .filter_map(|predicate| { + let mut join_cols_to_replace = HashMap::new(); + + let columns = match predicate.to_columns() { + Ok(columns) => columns, + Err(e) => return Some(Err(e)), + }; + + for col in columns.iter() { + for (l, r) in join_col_keys.iter() { + if col == *l { + join_cols_to_replace.insert(col, *r); + break; + } else if col == *r { + join_cols_to_replace.insert(col, *l); + break; + } + } + } - match push_down_all_join( - predicates, - infer_predicates, - plan, - &join.left, - &join.right, - on_filters, - is_inner_join, - ) { - Ok(plan) => Ok(Transformed::yes(plan.data)), - Err(e) => Err(e), - } + if join_cols_to_replace.is_empty() { + return None; + } + + let join_side_predicate = + match replace_col(predicate.clone(), &join_cols_to_replace) { + Ok(p) => p, + Err(e) => { + return Some(Err(e)); + } + }; + + Some(Ok(join_side_predicate)) + }) + .collect::>>() } impl OptimizerRule for PushDownFilter { @@ -641,46 +620,57 @@ impl OptimizerRule for PushDownFilter { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let filter = match plan { - LogicalPlan::Filter(ref filter) => filter, - LogicalPlan::Join(ref join) => return push_down_join(&plan, join, None), - _ => return Ok(Transformed::no(plan)), + if let LogicalPlan::Join(join) = plan { + return push_down_join(join, None); + }; + + let plan_schema = plan.schema().clone(); + + let LogicalPlan::Filter(mut filter) = plan else { + return Ok(Transformed::no(plan)); }; - let child_plan = filter.input.as_ref(); - let new_plan = match child_plan { - LogicalPlan::Filter(ref child_filter) => { - let parents_predicates = split_conjunction(&filter.predicate); - let set: HashSet<&&Expr> = parents_predicates.iter().collect(); + match unwrap_arc(filter.input) { + LogicalPlan::Filter(child_filter) => { + let parents_predicates = split_conjunction_owned(filter.predicate); + // remove duplicated filters + let child_predicates = split_conjunction_owned(child_filter.predicate); let new_predicates = parents_predicates - .iter() - .chain( - split_conjunction(&child_filter.predicate) - .iter() - .filter(|e| !set.contains(e)), - ) - .map(|e| (*e).clone()) + .into_iter() + .chain(child_predicates) + // use IndexSet to remove dupes while preserving predicate order + .collect::>() + .into_iter() .collect::>(); - let new_predicate = conjunction(new_predicates).ok_or_else(|| { - plan_datafusion_err!("at least one expression exists") - })?; + + let Some(new_predicate) = conjunction(new_predicates) else { + return plan_err!("at least one expression exists"); + }; let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, - child_filter.input.clone(), + child_filter.input, )?); - self.rewrite(new_filter, _config)?.data + self.rewrite(new_filter, _config) } - LogicalPlan::Repartition(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Sort(_) => { - let new_filter = plan.with_new_exprs( - plan.expressions(), - vec![child_plan.inputs()[0].clone()], - )?; - child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? 
+ LogicalPlan::Repartition(repartition) => { + let new_filter = + Filter::try_new(filter.predicate, repartition.input.clone()) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Repartition(repartition), new_filter) } - LogicalPlan::SubqueryAlias(ref subquery_alias) => { + LogicalPlan::Distinct(distinct) => { + let new_filter = + Filter::try_new(filter.predicate, distinct.input().clone()) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Distinct(distinct), new_filter) + } + LogicalPlan::Sort(sort) => { + let new_filter = Filter::try_new(filter.predicate, sort.input.clone()) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Sort(sort), new_filter) + } + LogicalPlan::SubqueryAlias(subquery_alias) => { let mut replace_map = HashMap::new(); for (i, (qualifier, field)) in subquery_alias.input.schema().iter().enumerate() @@ -692,15 +682,15 @@ impl OptimizerRule for PushDownFilter { Expr::Column(Column::new(qualifier.cloned(), field.name())), ); } - let new_predicate = - replace_cols_by_name(filter.predicate.clone(), &replace_map)?; + let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; + let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, subquery_alias.input.clone(), )?); - child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? + insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) } - LogicalPlan::Projection(ref projection) => { + LogicalPlan::Projection(projection) => { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile // predicates that are not used in the filter. However, we should re-writes all predicate expressions. // collect projection. @@ -711,10 +701,7 @@ impl OptimizerRule for PushDownFilter { .enumerate() .map(|(i, (qualifier, field))| { // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), - expr => expr.clone(), - }; + let expr = projection.expr[i].clone().unalias(); (qualified_name(qualifier, field.name()), expr) }) @@ -741,23 +728,24 @@ impl OptimizerRule for PushDownFilter { )?); match conjunction(keep_predicates) { - None => child_plan.with_new_exprs( - child_plan.expressions(), - vec![new_filter], - )?, - Some(keep_predicate) => { - let child_plan = child_plan.with_new_exprs( - child_plan.expressions(), - vec![new_filter], - )?; - LogicalPlan::Filter(Filter::try_new( - keep_predicate, - Arc::new(child_plan), - )?) - } + None => insert_below( + LogicalPlan::Projection(projection), + new_filter, + ), + Some(keep_predicate) => insert_below( + LogicalPlan::Projection(projection), + new_filter, + )? + .map_data(|child_plan| { + Filter::try_new(keep_predicate, Arc::new(child_plan)) + .map(LogicalPlan::Filter) + }), } } - None => return Ok(Transformed::no(plan)), + None => { + filter.input = Arc::new(LogicalPlan::Projection(projection)); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } } } LogicalPlan::Union(ref union) => { @@ -780,12 +768,12 @@ impl OptimizerRule for PushDownFilter { input.clone(), )?))) } - LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::Union(Union { inputs, - schema: plan.schema().clone(), - }) + schema: plan_schema.clone(), + }))) } - LogicalPlan::Aggregate(ref agg) => { + LogicalPlan::Aggregate(agg) => { // We can push down Predicate which in groupby_expr. 
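For `Aggregate` nodes (handled further below), a predicate may only move beneath the aggregation if every column it references is part of the `GROUP BY` key; anything that touches an aggregate output has to stay above it. A toy model of that test with plain strings in place of `Expr`/`Column`:

```rust
use std::collections::HashSet;

// Toy model of the split performed for Aggregate nodes: a predicate may move
// below the aggregation only if every column it uses is a GROUP BY column.
fn can_push_below_aggregate(predicate_cols: &[&str], group_by: &HashSet<&str>) -> bool {
    predicate_cols.iter().all(|c| group_by.contains(c))
}

fn main() {
    // SELECT a, SUM(b) AS total FROM t GROUP BY a
    let group_by: HashSet<&str> = ["a"].into_iter().collect();

    // WHERE a > 5 only touches the grouping column, so it can run before the
    // aggregation; a predicate on `total` depends on the aggregate output and
    // must stay above it.
    assert!(can_push_below_aggregate(&["a"], &group_by));
    assert!(!can_push_below_aggregate(&["total"], &group_by));
}
```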
let group_expr_columns = agg .group_expr @@ -818,49 +806,33 @@ impl OptimizerRule for PushDownFilter { .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) .collect::>>()?; - let child = match conjunction(replaced_push_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - agg.input.clone(), - )?), - None => (*agg.input).clone(), - }; - let new_agg = filter - .input - .with_new_exprs(filter.input.expressions(), vec![child])?; - match conjunction(keep_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_agg), - )?), - None => new_agg, - } - } - LogicalPlan::Join(ref join) => { - push_down_join( - &unwrap_arc(filter.clone().input), - join, - Some(&filter.predicate), - )? - .data + let agg_input = agg.input.clone(); + Transformed::yes(LogicalPlan::Aggregate(agg)) + .transform_data(|new_plan| { + // If we have a filter to push, we push it down to the input of the aggregate + if let Some(predicate) = conjunction(replaced_push_predicates) { + let new_filter = make_filter(predicate, agg_input)?; + insert_below(new_plan, new_filter) + } else { + Ok(Transformed::no(new_plan)) + } + })? + .map_data(|child_plan| { + // if there are any remaining predicates we can't push, add them + // back as a filter + if let Some(predicate) = conjunction(keep_predicates) { + make_filter(predicate, Arc::new(child_plan)) + } else { + Ok(child_plan) + } + }) } - LogicalPlan::CrossJoin(ref cross_join) => { + LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), + LogicalPlan::CrossJoin(cross_join) => { let predicates = split_conjunction_owned(filter.predicate.clone()); - let join = convert_cross_join_to_inner_join(cross_join.clone())?; - let join_plan = LogicalPlan::Join(join); - let inputs = join_plan.inputs(); - let left = inputs[0]; - let right = inputs[1]; - let plan = push_down_all_join( - predicates, - vec![], - &join_plan, - left, - right, - vec![], - true, - )?; - convert_to_cross_join_if_beneficial(plan.data)? 
+ let join = convert_cross_join_to_inner_join(cross_join)?; + let plan = push_down_all_join(predicates, vec![], join, vec![])?; + convert_to_cross_join_if_beneficial(plan.data) } LogicalPlan::TableScan(ref scan) => { let filter_predicates = split_conjunction(&filter.predicate); @@ -901,25 +873,47 @@ impl OptimizerRule for PushDownFilter { fetch: scan.fetch, }); - match conjunction(new_predicate) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_scan), - )?), - None => new_scan, - } + Transformed::yes(new_scan).transform_data(|new_scan| { + if let Some(predicate) = conjunction(new_predicate) { + make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes) + } else { + Ok(Transformed::no(new_scan)) + } + }) } - LogicalPlan::Extension(ref extension_plan) => { + LogicalPlan::Extension(extension_plan) => { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); - let predicates = split_conjunction_owned(filter.predicate.clone()); + // determine if we can push any predicates down past the extension node + + // each element is true for push, false to keep + let predicate_push_or_keep = split_conjunction(&filter.predicate) + .iter() + .map(|expr| { + let cols = expr.to_columns()?; + if cols.iter().any(|c| prevent_cols.contains(&c.name)) { + Ok(false) // No push (keep) + } else { + Ok(true) // push + } + }) + .collect::>>()?; + // all predicates are kept, no changes needed + if predicate_push_or_keep.iter().all(|&x| !x) { + filter.input = Arc::new(LogicalPlan::Extension(extension_plan)); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + + // going to push some predicates down, so split the predicates let mut keep_predicates = vec![]; let mut push_predicates = vec![]; - for expr in predicates { - let cols = expr.to_columns()?; - if cols.iter().any(|c| prevent_cols.contains(&c.name)) { + for (push, expr) in predicate_push_or_keep + .into_iter() + .zip(split_conjunction_owned(filter.predicate).into_iter()) + { + if !push { keep_predicates.push(expr); } else { push_predicates.push(expr); @@ -941,22 +935,65 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. + let child_plan = LogicalPlan::Extension(extension_plan); let new_extension = child_plan.with_new_exprs(child_plan.expressions(), new_children)?; - match conjunction(keep_predicates) { + let new_plan = match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, Arc::new(new_extension), )?), None => new_extension, - } + }; + Ok(Transformed::yes(new_plan)) } - _ => return Ok(Transformed::no(plan)), - }; + child => { + filter.input = Arc::new(child); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } + } + } +} + +/// Creates a new LogicalPlan::Filter node. +pub fn make_filter(predicate: Expr, input: Arc) -> Result { + Filter::try_new(predicate, input).map(LogicalPlan::Filter) +} - Ok(Transformed::yes(new_plan)) +/// Replace the existing child of the single input node with `new_child`. 
+/// +/// Starting: +/// ```text +/// plan +/// child +/// ``` +/// +/// Ending: +/// ```text +/// plan +/// new_child +/// ``` +fn insert_below( + plan: LogicalPlan, + new_child: LogicalPlan, +) -> Result> { + let mut new_child = Some(new_child); + let transformed_plan = plan.map_children(|_child| { + if let Some(new_child) = new_child.take() { + Ok(Transformed::yes(new_child)) + } else { + // already took the new child + internal_err!("node had more than one input") + } + })?; + + // make sure we did the actual replacement + if new_child.is_some() { + return internal_err!("node had no inputs"); } + + Ok(transformed_plan) } impl PushDownFilter { @@ -985,21 +1022,27 @@ fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { /// Converts the given inner join with an empty equality predicate and an /// empty filter condition to a cross join. -fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { - if let LogicalPlan::Join(join) = &plan { +fn convert_to_cross_join_if_beneficial( + plan: LogicalPlan, +) -> Result> { + match plan { // Can be converted back to cross join - if join.on.is_empty() && join.filter.is_none() { - return LogicalPlanBuilder::from(join.left.as_ref().clone()) - .cross_join(join.right.as_ref().clone())? - .build(); + LogicalPlan::Join(join) if join.on.is_empty() && join.filter.is_none() => { + LogicalPlanBuilder::from(unwrap_arc(join.left)) + .cross_join(unwrap_arc(join.right))? + .build() + .map(Transformed::yes) } - } else if let LogicalPlan::Filter(filter) = &plan { - let new_input = - convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?; - return Filter::try_new(filter.predicate.clone(), Arc::new(new_input)) - .map(LogicalPlan::Filter); + LogicalPlan::Filter(filter) => convert_to_cross_join_if_beneficial(unwrap_arc( + filter.input, + ))? + .transform_data(|child_plan| { + Filter::try_new(filter.predicate, Arc::new(child_plan)) + .map(LogicalPlan::Filter) + .map(Transformed::yes) + }), + plan => Ok(Transformed::no(plan)), } - Ok(plan) } /// replaces columns by its name on the projection. diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 1af246fc556d..b97dff74d979 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -17,16 +17,16 @@ //! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan +use std::cmp::min; use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::Result; -use datafusion_expr::logical_plan::{ - Join, JoinType, Limit, LogicalPlan, Sort, TableScan, Union, -}; -use datafusion_expr::CrossJoin; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; +use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; /// Optimization rule that tries to push down `LIMIT`. 
/// @@ -45,166 +45,120 @@ impl PushDownLimit { impl OptimizerRule for PushDownLimit { fn try_optimize( &self, - plan: &LogicalPlan, + _plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - use std::cmp::min; + internal_err!("Should have called PushDownLimit::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } - let LogicalPlan::Limit(limit) = plan else { - return Ok(None); + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + let LogicalPlan::Limit(mut limit) = plan else { + return Ok(Transformed::no(plan)); }; - if let LogicalPlan::Limit(child) = &*limit.input { - // Merge the Parent Limit and the Child Limit. - - // Case 0: Parent and Child are disjoint. (child_fetch <= skip) - // Before merging: - // |........skip........|---fetch-->| Parent Limit - // |...child_skip...|---child_fetch-->| Child Limit - // After merging: - // |.........(child_skip + skip).........| - // Before merging: - // |...skip...|------------fetch------------>| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---(child_fetch - skip)-->| - - // Case 1: Parent is beyond the range of Child. (skip < child_fetch <= skip + fetch) - // Before merging: - // |...skip...|------------fetch------------>| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---(child_fetch - skip)-->| - - // Case 2: Parent is in the range of Child. (skip + fetch < child_fetch) - // Before merging: - // |...skip...|---fetch-->| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---fetch-->| - let parent_skip = limit.skip; - let new_fetch = match (limit.fetch, child.fetch) { - (Some(fetch), Some(child_fetch)) => { - Some(min(fetch, child_fetch.saturating_sub(parent_skip))) - } - (Some(fetch), None) => Some(fetch), - (None, Some(child_fetch)) => { - Some(child_fetch.saturating_sub(parent_skip)) - } - (None, None) => None, - }; + let Limit { skip, fetch, input } = limit; + let input = input; + + // Merge the Parent Limit and the Child Limit. 
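The rewritten rule handles each child operator of the `Limit` directly. For a `TableScan` (see the arm further below), it caps the scan's own `fetch` at `skip + fetch`, keeping an already-tighter fetch when one exists, so the source never produces more rows than the limit can use. A standalone sketch of that arithmetic:

```rust
// Stand-alone sketch of the fetch computation in the TableScan arm below.
fn combined_scan_fetch(existing_fetch: Option<usize>, skip: usize, fetch: usize) -> Option<usize> {
    // `LIMIT 0` needs no rows at all; otherwise at most `skip + fetch` rows
    // can ever reach the output.
    let rows_needed = if fetch != 0 { fetch + skip } else { 0 };
    existing_fetch.map(|f| f.min(rows_needed)).or(Some(rows_needed))
}

fn main() {
    // LIMIT 10 OFFSET 5 over a plain scan: the scan only needs 15 rows.
    assert_eq!(combined_scan_fetch(None, 5, 10), Some(15));
    // If the scan was already capped at 8 rows, the tighter bound wins.
    assert_eq!(combined_scan_fetch(Some(8), 5, 10), Some(8));
}
```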
+ if let LogicalPlan::Limit(child) = input.as_ref() { + let (skip, fetch) = + combine_limit(limit.skip, limit.fetch, child.skip, child.fetch); let plan = LogicalPlan::Limit(Limit { - skip: child.skip + parent_skip, - fetch: new_fetch, - input: Arc::new((*child.input).clone()), + skip, + fetch, + input: Arc::clone(&child.input), }); - return self - .try_optimize(&plan, _config) - .map(|opt_plan| opt_plan.or_else(|| Some(plan))); + + // recursively reapply the rule on the new plan + return self.rewrite(plan, _config); } - let Some(fetch) = limit.fetch else { - return Ok(None); + // no fetch to push, so return the original plan + let Some(fetch) = fetch else { + return Ok(Transformed::no(LogicalPlan::Limit(Limit { + skip, + fetch, + input, + }))); }; - let skip = limit.skip; - match limit.input.as_ref() { - LogicalPlan::TableScan(scan) => { - let limit = if fetch != 0 { fetch + skip } else { 0 }; - let new_fetch = scan.fetch.map(|x| min(x, limit)).or(Some(limit)); + match unwrap_arc(input) { + LogicalPlan::TableScan(mut scan) => { + let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; + let new_fetch = scan + .fetch + .map(|x| min(x, rows_needed)) + .or(Some(rows_needed)); if new_fetch == scan.fetch { - Ok(None) + original_limit(skip, fetch, LogicalPlan::TableScan(scan)) } else { - let new_input = LogicalPlan::TableScan(TableScan { - table_name: scan.table_name.clone(), - source: scan.source.clone(), - projection: scan.projection.clone(), - filters: scan.filters.clone(), - fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)), - projected_schema: scan.projected_schema.clone(), - }); - plan.with_new_exprs(plan.expressions(), vec![new_input]) - .map(Some) + // push limit into the table scan itself + scan.fetch = scan + .fetch + .map(|x| min(x, rows_needed)) + .or(Some(rows_needed)); + transformed_limit(skip, fetch, LogicalPlan::TableScan(scan)) } } - LogicalPlan::Union(union) => { - let new_inputs = union + LogicalPlan::Union(mut union) => { + // push limits to each input of the union + union.inputs = union .inputs - .iter() - .map(|x| { - Ok(Arc::new(LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(fetch + skip), - input: x.clone(), - }))) - }) - .collect::>()?; - let union = LogicalPlan::Union(Union { - inputs: new_inputs, - schema: union.schema.clone(), - }); - plan.with_new_exprs(plan.expressions(), vec![union]) - .map(Some) + .into_iter() + .map(|input| make_arc_limit(0, fetch + skip, input)) + .collect(); + transformed_limit(skip, fetch, LogicalPlan::Union(union)) } - LogicalPlan::CrossJoin(cross_join) => { - let new_left = LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(fetch + skip), - input: cross_join.left.clone(), - }); - let new_right = LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(fetch + skip), - input: cross_join.right.clone(), - }); - let new_cross_join = LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(new_left), - right: Arc::new(new_right), - schema: plan.schema().clone(), - }); - plan.with_new_exprs(plan.expressions(), vec![new_cross_join]) - .map(Some) + LogicalPlan::CrossJoin(mut cross_join) => { + // push limit to both inputs + cross_join.left = make_arc_limit(0, fetch + skip, cross_join.left); + cross_join.right = make_arc_limit(0, fetch + skip, cross_join.right); + transformed_limit(skip, fetch, LogicalPlan::CrossJoin(cross_join)) } - LogicalPlan::Join(join) => { - if let Some(new_join) = push_down_join(join, fetch + skip) { - let inputs = vec![LogicalPlan::Join(new_join)]; - plan.with_new_exprs(plan.expressions(), inputs).map(Some) - } else { 
- Ok(None) - } - } + LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip) + .update_data(|join| { + make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) + })), - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(mut sort) => { let new_fetch = { let sort_fetch = skip + fetch; Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch)) }; if new_fetch == sort.fetch { - Ok(None) + original_limit(skip, fetch, LogicalPlan::Sort(sort)) } else { - let new_sort = LogicalPlan::Sort(Sort { - expr: sort.expr.clone(), - input: sort.input.clone(), - fetch: new_fetch, - }); - plan.with_new_exprs(plan.expressions(), vec![new_sort]) - .map(Some) + sort.fetch = new_fetch; + limit.input = Arc::new(LogicalPlan::Sort(sort)); + Ok(Transformed::yes(LogicalPlan::Limit(limit))) } } - child_plan @ (LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_)) => { + LogicalPlan::Projection(mut proj) => { + // commute + limit.input = Arc::clone(&proj.input); + let new_limit = LogicalPlan::Limit(limit); + proj.input = Arc::new(new_limit); + Ok(Transformed::yes(LogicalPlan::Projection(proj))) + } + LogicalPlan::SubqueryAlias(mut subquery_alias) => { // commute - let new_limit = plan.with_new_exprs( - plan.expressions(), - vec![child_plan.inputs()[0].clone()], - )?; - child_plan - .with_new_exprs(child_plan.expressions(), vec![new_limit]) - .map(Some) + limit.input = Arc::clone(&subquery_alias.input); + let new_limit = LogicalPlan::Limit(limit); + subquery_alias.input = Arc::new(new_limit); + Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias))) } - _ => Ok(None), + input => original_limit(skip, fetch, input), } } @@ -217,14 +171,142 @@ impl OptimizerRule for PushDownLimit { } } -fn push_down_join(join: &Join, limit: usize) -> Option { +/// Wrap the input plan with a limit node +/// +/// Original: +/// ```text +/// input +/// ``` +/// +/// Return +/// ```text +/// Limit: skip=skip, fetch=fetch +/// input +/// ``` +fn make_limit(skip: usize, fetch: usize, input: Arc) -> LogicalPlan { + LogicalPlan::Limit(Limit { + skip, + fetch: Some(fetch), + input, + }) +} + +/// Wrap the input plan with a limit node +fn make_arc_limit( + skip: usize, + fetch: usize, + input: Arc, +) -> Arc { + Arc::new(make_limit(skip, fetch, input)) +} + +/// Returns the original limit (not transformed) +fn original_limit( + skip: usize, + fetch: usize, + input: LogicalPlan, +) -> Result> { + Ok(Transformed::no(LogicalPlan::Limit(Limit { + skip, + fetch: Some(fetch), + input: Arc::new(input), + }))) +} + +/// Returns a transformed limit +fn transformed_limit( + skip: usize, + fetch: usize, + input: LogicalPlan, +) -> Result> { + Ok(Transformed::yes(LogicalPlan::Limit(Limit { + skip, + fetch: Some(fetch), + input: Arc::new(input), + }))) +} + +/// Combines two limits into a single limit +/// +/// Returns the combined limit `(skip, fetch)` +/// +/// # Case 0: Parent and Child are disjoint. (`child_fetch <= skip`) +/// +/// ```text +/// Before merging: +/// |........skip........|---fetch-->| Parent Limit +/// |...child_skip...|---child_fetch-->| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |.........(child_skip + skip).........| +/// ``` +/// +/// # Case 1: Parent is beyond the range of Child.
(`skip < child_fetch <= skip + fetch`) +/// +/// Before merging: +/// ```text +/// |...skip...|------------fetch------------>| Parent Limit +/// |...child_skip...|-------------child_fetch------------>| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |....(child_skip + skip)....|---(child_fetch - skip)-->| +/// ``` +/// +/// # Case 2: Parent is in the range of Child. (`skip + fetch < child_fetch`) +/// Before merging: +/// ```text +/// |...skip...|---fetch-->| Parent Limit +/// |...child_skip...|-------------child_fetch------------>| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |....(child_skip + skip)....|---fetch-->| +/// ``` +fn combine_limit( + parent_skip: usize, + parent_fetch: Option, + child_skip: usize, + child_fetch: Option, +) -> (usize, Option) { + let combined_skip = child_skip.saturating_add(parent_skip); + + let combined_fetch = match (parent_fetch, child_fetch) { + (Some(parent_fetch), Some(child_fetch)) => { + Some(min(parent_fetch, child_fetch.saturating_sub(parent_skip))) + } + (Some(parent_fetch), None) => Some(parent_fetch), + (None, Some(child_fetch)) => Some(child_fetch.saturating_sub(parent_skip)), + (None, None) => None, + }; + + (combined_skip, combined_fetch) +} + +/// Adds a limit to the inputs of a join, if possible +fn push_down_join(mut join: Join, limit: usize) -> Transformed { use JoinType::*; fn is_no_join_condition(join: &Join) -> bool { join.on.is_empty() && join.filter.is_none() } - let (left_limit, right_limit) = if is_no_join_condition(join) { + let (left_limit, right_limit) = if is_no_join_condition(&join) { match join.join_type { Left | Right | Full => (Some(limit), Some(limit)), LeftAnti | LeftSemi => (Some(limit), None), @@ -239,37 +321,16 @@ fn push_down_join(join: &Join, limit: usize) -> Option { } }; - match (left_limit, right_limit) { - (None, None) => None, - _ => { - let left = match left_limit { - Some(limit) => Arc::new(LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(limit), - input: join.left.clone(), - })), - None => join.left.clone(), - }; - let right = match right_limit { - Some(limit) => Arc::new(LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(limit), - input: join.right.clone(), - })), - None => join.right.clone(), - }; - Some(Join { - left, - right, - on: join.on.clone(), - filter: join.filter.clone(), - join_type: join.join_type, - join_constraint: join.join_constraint, - schema: join.schema.clone(), - null_equals_null: join.null_equals_null, - }) - } + if left_limit.is_none() && right_limit.is_none() { + return Transformed::no(join); + } + if let Some(limit) = left_limit { + join.left = make_arc_limit(0, limit, join.left); + } + if let Some(limit) = right_limit { + join.right = make_arc_limit(0, limit, join.right); } + Transformed::yes(join) } #[cfg(test)] diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index b7fce68fb3cc..cb28961497f4 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -29,7 +29,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use 
datafusion_expr::utils::conjunction; @@ -50,7 +50,7 @@ impl ScalarSubqueryToJoin { /// # Arguments /// * `predicate` - A conjunction to split and search /// - /// Returns a tuple (subqueries, rewrite expression) + /// Returns a tuple (subqueries, alias) fn extract_subquery_exprs( &self, predicate: &Expr, @@ -71,19 +71,36 @@ impl ScalarSubqueryToJoin { impl OptimizerRule for ScalarSubqueryToJoin { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called ScalarSubqueryToJoin::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { match plan { LogicalPlan::Filter(filter) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !contains_scalar_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( &filter.predicate, config.alias_generator(), )?; if subqueries.is_empty() { - // regular filter, no subquery exists clause here - return Ok(None); + return internal_err!("Expected subqueries not found in filter"); } // iterate through all subqueries in predicate, turning each into a left join @@ -94,16 +111,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = rewrite_expr - .clone() .transform_up(|expr| { - if let Expr::Column(col) = &expr { - if let Some(map_expr) = - expr_check_map.get(&col.name) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) } else { Ok(Transformed::no(expr)) } @@ -113,15 +127,21 @@ impl OptimizerRule for ScalarSubqueryToJoin { cur_input = optimized_subquery; } else { // if we can't handle all of the subqueries then bail for now - return Ok(None); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); } } let new_plan = LogicalPlanBuilder::from(cur_input) .filter(rewrite_expr)? 
.build()?; - Ok(Some(new_plan)) + Ok(Transformed::yes(new_plan)) } LogicalPlan::Projection(projection) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !projection.expr.iter().any(contains_scalar_subquery) { + return Ok(Transformed::no(LogicalPlan::Projection(projection))); + } + let mut all_subqueryies = vec![]; let mut expr_to_rewrite_expr_map = HashMap::new(); let mut subquery_to_expr_map = HashMap::new(); @@ -135,8 +155,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); } if all_subqueryies.is_empty() { - // regular projection, no subquery exists clause here - return Ok(None); + return internal_err!("Expected subqueries not found in projection"); } // iterate through all subqueries in predicate, turning each into a left join let mut cur_input = projection.input.as_ref().clone(); @@ -153,14 +172,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { let new_expr = rewrite_expr .clone() .transform_up(|expr| { - if let Expr::Column(col) = &expr { - if let Some(map_expr) = + // replace column references with entry in map, if it exists + if let Some(map_expr) = + expr.try_as_col().and_then(|col| { expr_check_map.get(&col.name) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } + }) + { + Ok(Transformed::yes(map_expr.clone())) } else { Ok(Transformed::no(expr)) } @@ -172,7 +190,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } else { // if we can't handle all of the subqueries then bail for now - return Ok(None); + return Ok(Transformed::no(LogicalPlan::Projection(projection))); } } @@ -190,10 +208,10 @@ impl OptimizerRule for ScalarSubqueryToJoin { let new_plan = LogicalPlanBuilder::from(cur_input) .project(proj_exprs)? 
.build()?; - Ok(Some(new_plan)) + Ok(Transformed::yes(new_plan)) } - _ => Ok(None), + plan => Ok(Transformed::no(plan)), } } @@ -206,6 +224,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } +/// Returns true if the expression has a scalar subquery somewhere in it +/// false otherwise +fn contains_scalar_subquery(expr: &Expr) -> bool { + expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + .expect("Inner is always Ok") +} + struct ExtractScalarSubQuery { sub_query_info: Vec<(Subquery, String)>, alias_gen: Arc, @@ -280,16 +305,7 @@ fn build_join( subquery_alias: &str, ) -> Result)>> { let subquery_plan = subquery.subquery.as_ref(); - let mut pull_up = PullUpCorrelatedExpr { - join_filters: vec![], - correlated_subquery_cols_map: Default::default(), - in_predicate_opt: None, - exists_sub_query: false, - can_pull_up: true, - need_handle_count_bug: true, - collected_count_expr_map: Default::default(), - pull_up_having_expr: None, - }; + let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 55052542a8bf..455d659fb25e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1759,7 +1759,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { mod tests { use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ - function::AggregateFunctionSimplification, interval_arithmetic::Interval, *, + function::{AccumulatorArgs, AggregateFunctionSimplification}, + interval_arithmetic::Interval, + *, }; use std::{ collections::HashMap, @@ -3783,7 +3785,7 @@ mod tests { unimplemented!("not needed for tests") } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { unimplemented!("not needed for testing") } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 10a3a51ec4d8..4334e64082df 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -91,6 +91,28 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { } else if !matches!(fun, Sum | Min | Max) { return Ok(false); } + } else if let Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::UDF(fun), + distinct, + args, + filter, + order_by, + null_treatment: _, + }) = expr + { + if filter.is_some() || order_by.is_some() { + return Ok(false); + } + aggregate_count += 1; + if *distinct { + for e in args { + fields_set.insert(e.canonical_name()); + } + } else if fun.name() != "SUM" && fun.name() != "MIN" && fun.name() != "MAX" { + return Ok(false); + } + } else { + return Ok(false); } } Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index d2e3414fbfce..da24f335b2f8 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -20,6 +20,7 @@ pub mod utils; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{not_impl_err, Result}; +use 
datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, @@ -34,6 +35,7 @@ use self::utils::{down_cast_any_ref, ordering_fields}; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. +#[allow(clippy::too_many_arguments)] pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], @@ -42,6 +44,7 @@ pub fn create_aggregate_expr( schema: &Schema, name: impl Into, ignore_nulls: bool, + is_distinct: bool, ) -> Result> { let input_exprs_types = input_phy_exprs .iter() @@ -71,6 +74,8 @@ pub fn create_aggregate_expr( ordering_req: ordering_req.to_vec(), ignore_nulls, ordering_fields, + is_distinct, + input_type: input_exprs_types[0].clone(), })) } @@ -162,6 +167,8 @@ pub struct AggregateFunctionExpr { ordering_req: LexOrdering, ignore_nulls: bool, ordering_fields: Vec, + is_distinct: bool, + input_type: DataType, } impl AggregateFunctionExpr { @@ -169,6 +176,11 @@ impl AggregateFunctionExpr { pub fn fun(&self) -> &AggregateUDF { &self.fun } + + /// Return if the aggregation is distinct + pub fn is_distinct(&self) -> bool { + self.is_distinct + } } impl AggregateExpr for AggregateFunctionExpr { @@ -182,11 +194,15 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - self.fun.state_fields( - self.name(), - self.data_type.clone(), - self.ordering_fields.clone(), - ) + let args = StateFieldsArgs { + name: &self.name, + input_type: &self.input_type, + return_type: &self.data_type, + ordering_fields: &self.ordering_fields, + is_distinct: self.is_distinct, + }; + + self.fun.state_fields(args) } fn field(&self) -> Result { @@ -194,12 +210,15 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - let acc_args = AccumulatorArgs::new( - &self.data_type, - &self.schema, - self.ignore_nulls, - &self.sort_exprs, - ); + let acc_args = AccumulatorArgs { + data_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + sort_exprs: &self.sort_exprs, + is_distinct: self.is_distinct, + input_type: &self.input_type, + args_num: self.args.len(), + }; self.fun.accumulator(acc_args) } @@ -264,7 +283,16 @@ impl AggregateExpr for AggregateFunctionExpr { } fn groups_accumulator_supported(&self) -> bool { - self.fun.groups_accumulator_supported() + let args = AccumulatorArgs { + data_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + sort_exprs: &self.sort_exprs, + is_distinct: self.is_distinct, + input_type: &self.input_type, + args_num: self.args.len(), + }; + self.fun.groups_accumulator_supported(args) } fn create_groups_accumulator(&self) -> Result> { diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 53e3134a1b05..f335958698ab 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -19,6 +19,5 @@ pub mod aggregate; pub mod expressions; pub mod physical_expr; pub mod sort_expr; -pub mod sort_properties; pub mod tree_node; pub mod utils; diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index a0f8bdf10377..00b3dd725dc2 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ 
b/datafusion/physical-expr-common/src/physical_expr.rs @@ -20,17 +20,17 @@ use std::fmt::{Debug, Display}; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::utils::scatter; + use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::ColumnarValue; -use crate::sort_properties::SortProperties; -use crate::utils::scatter; - /// See [create_physical_expr](https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html) /// for examples of creating `PhysicalExpr` from `Expr` pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { @@ -154,17 +154,13 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { /// directly because it must remain object safe. fn dyn_hash(&self, _state: &mut dyn Hasher); - /// The order information of a PhysicalExpr can be estimated from its children. - /// This is especially helpful for projection expressions. If we can ensure that the - /// order of a PhysicalExpr to project matches with the order of SortExec, we can - /// eliminate that SortExecs. - /// - /// By recursively calling this function, we can obtain the overall order - /// information of the PhysicalExpr. Since `SortOptions` cannot fully handle - /// the propagation of unordered columns and literals, the `SortProperties` - /// struct is used. - fn get_ordering(&self, _children: &[SortProperties]) -> SortProperties { - SortProperties::Unordered + /// Calculates the properties of this [`PhysicalExpr`] based on its + /// children's properties (i.e. order and range), recursively aggregating + /// the information from its children. In cases where the [`PhysicalExpr`] + /// has no children (e.g., `Literal` or `Column`), these properties should + /// be specified externally, as the function defaults to unknown properties. + fn get_properties(&self, _children: &[ExprProperties]) -> Result { + Ok(ExprProperties::new_unknown()) } } diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 459b5a4849cb..601d344e4aac 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -15,13 +15,34 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}, - compute::{and_kleene, is_not_null, SlicesIterator}, +use std::sync::Arc; + +use crate::{ + physical_expr::PhysicalExpr, sort_expr::PhysicalSortExpr, tree_node::ExprContext, }; -use datafusion_common::Result; -use crate::sort_expr::PhysicalSortExpr; +use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; +use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; +use datafusion_common::Result; +use datafusion_expr::sort_properties::ExprProperties; + +/// Represents a [`PhysicalExpr`] node with associated properties (order and +/// range) in a context where properties are tracked. +pub type ExprPropertiesNode = ExprContext; + +impl ExprPropertiesNode { + /// Constructs a new `ExprPropertiesNode` with unknown properties for a + /// given physical expression. This node initializes with default properties + /// and recursively applies this to all child expressions. 
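As a rough illustration of how this wrapper is meant to be used (a hypothetical, standalone sketch, not part of this patch; it assumes only the `ExprPropertiesNode` alias and the `Column` expression from `datafusion-physical-expr-common` that appear in this diff):

```rust
use std::sync::Arc;

use datafusion_physical_expr_common::expressions::column::Column;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::utils::ExprPropertiesNode;

fn wrap_column_example() {
    // Wrap a bare column reference; the node (and any children) start out
    // with "unknown" properties until an optimizer pass fills them in.
    let col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
    let node = ExprPropertiesNode::new_unknown(col);

    // A column is a leaf expression, so the wrapper has no children.
    assert!(node.children.is_empty());
}
```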
+ pub fn new_unknown(expr: Arc) -> Self { + let children = expr.children().into_iter().map(Self::new_unknown).collect(); + Self { + expr, + data: ExprProperties::new_unknown(), + children, + } + } +} /// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` /// are taken, when the mask evaluates `false` values null values are filled. diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index b8671c39a943..244a44acdcb5 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -153,12 +153,11 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - let array = &states[0]; - - assert_eq!(array.len(), 1, "state array should only include 1 row!"); - // Unwrap outer ListArray then do update batch - let inner_array = array.as_list::().value(0); - self.update_batch(&[inner_array]) + states[0] + .as_list::() + .iter() + .flatten() + .try_for_each(|val| self.update_batch(&[val])) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 95ae3207462e..50bd24c487bf 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -1103,509 +1103,3 @@ impl Accumulator for SlidingMinAccumulator { std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::{aggregate, aggregate_new}; - use crate::{generic_test_op, generic_test_op_new}; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use datafusion_common::ScalarValue::Decimal128; - - #[test] - fn min_decimal() -> Result<()> { - // min - let left = ScalarValue::Decimal128(Some(123), 10, 2); - let right = ScalarValue::Decimal128(Some(124), 10, 2); - let result = min(&left, &right)?; - assert_eq!(result, left); - - // min batch - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - - let result = min_batch(&array)?; - assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0)); - - // min batch without values - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(0) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - let result = min_batch(&array)?; - assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); - - // min batch with agg - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(Some(1), 10, 0) - ) - } - - #[test] - fn min_decimal_all_nulls() -> Result<()> { - // min batch all nulls - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(None, 10, 0) - ) - } - - #[test] - fn min_decimal_with_nulls() -> Result<()> { - // min batch with nulls - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(Some(1), 10, 0) - ) - } - - #[test] - fn 
max_decimal() -> Result<()> { - // max - let left = ScalarValue::Decimal128(Some(123), 10, 2); - let right = ScalarValue::Decimal128(Some(124), 10, 2); - let result = max(&left, &right)?; - assert_eq!(result, right); - - let right = ScalarValue::Decimal128(Some(124), 10, 3); - let result = max(&left, &right); - let err_msg = format!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3)) - ); - let expect = DataFusionError::Internal(err_msg); - assert!(expect - .strip_backtrace() - .starts_with(&result.unwrap_err().strip_backtrace())); - - // max batch - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 5)?, - ); - let result = max_batch(&array)?; - assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5)); - - // max batch without values - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(0) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - let result = max_batch(&array)?; - assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); - - // max batch with agg - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Max, - ScalarValue::Decimal128(Some(5), 10, 0) - ) - } - - #[test] - fn max_decimal_with_nulls() -> Result<()> { - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Max, - ScalarValue::Decimal128(Some(5), 10, 0) - ) - } - - #[test] - fn max_decimal_all_nulls() -> Result<()> { - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(None, 10, 0) - ) - } - - #[test] - fn max_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32)) - } - - #[test] - fn min_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32)) - } - - #[test] - fn max_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!(a, DataType::Utf8, Max, ScalarValue::from("d")) - } - - #[test] - fn max_large_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::LargeUtf8, - Max, - ScalarValue::LargeUtf8(Some("d".to_string())) - ) - } - - #[test] - fn min_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!(a, DataType::Utf8, Min, ScalarValue::from("a")) - } - - #[test] - fn min_large_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::LargeUtf8, - Min, - ScalarValue::LargeUtf8(Some("a".to_string())) - ) - } - - #[test] - fn max_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32)) - } - - #[test] - fn min_i32_with_nulls() -> Result<()> { - let a: ArrayRef = 
Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32)) - } - - #[test] - fn max_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Max, ScalarValue::Int32(None)) - } - - #[test] - fn min_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Min, ScalarValue::Int32(None)) - } - - #[test] - fn max_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Max, ScalarValue::from(5_u32)) - } - - #[test] - fn min_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Min, ScalarValue::from(1u32)) - } - - #[test] - fn max_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Max, ScalarValue::from(5_f32)) - } - - #[test] - fn min_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Min, ScalarValue::from(1_f32)) - } - - #[test] - fn max_f64() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Max, ScalarValue::from(5_f64)) - } - - #[test] - fn min_f64() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Min, ScalarValue::from(1_f64)) - } - - #[test] - fn min_date32() -> Result<()> { - let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date32, Min, ScalarValue::Date32(Some(1))) - } - - #[test] - fn min_date64() -> Result<()> { - let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date64, Min, ScalarValue::Date64(Some(1))) - } - - #[test] - fn max_date32() -> Result<()> { - let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date32, Max, ScalarValue::Date32(Some(5))) - } - - #[test] - fn max_date64() -> Result<()> { - let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date64, Max, ScalarValue::Date64(Some(5))) - } - - #[test] - fn min_time32second() -> Result<()> { - let a: ArrayRef = Arc::new(Time32SecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Second), - Min, - ScalarValue::Time32Second(Some(1)) - ) - } - - #[test] - fn max_time32second() -> Result<()> { - let a: ArrayRef = Arc::new(Time32SecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Second), - Max, - ScalarValue::Time32Second(Some(5)) - ) - } - - #[test] - fn min_time32millisecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time32MillisecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Millisecond), - Min, - ScalarValue::Time32Millisecond(Some(1)) - ) - } - - #[test] - fn max_time32millisecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time32MillisecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Millisecond), - 
Max, - ScalarValue::Time32Millisecond(Some(5)) - ) - } - - #[test] - fn min_time64microsecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64MicrosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Microsecond), - Min, - ScalarValue::Time64Microsecond(Some(1)) - ) - } - - #[test] - fn max_time64microsecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64MicrosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Microsecond), - Max, - ScalarValue::Time64Microsecond(Some(5)) - ) - } - - #[test] - fn min_time64nanosecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64NanosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Nanosecond), - Min, - ScalarValue::Time64Nanosecond(Some(1)) - ) - } - - #[test] - fn max_time64nanosecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64NanosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Nanosecond), - Max, - ScalarValue::Time64Nanosecond(Some(5)) - ) - } - - #[test] - fn max_new_timestamp_micro() -> Result<()> { - let dt = DataType::Timestamp(TimeUnit::Microsecond, None); - let actual = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]) - .with_data_type(dt.clone()); - let expected: ArrayRef = - Arc::new(TimestampMicrosecondArray::from(vec![5]).with_data_type(dt.clone())); - generic_test_op_new!(Arc::new(actual), dt.clone(), Max, &expected) - } - - #[test] - fn max_new_timestamp_micro_with_tz() -> Result<()> { - let dt = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())); - let actual = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]) - .with_data_type(dt.clone()); - let expected: ArrayRef = - Arc::new(TimestampMicrosecondArray::from(vec![5]).with_data_type(dt.clone())); - generic_test_op_new!(Arc::new(actual), dt.clone(), Max, &expected) - } - - #[test] - fn max_bool() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, false])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, true, false])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, false, true])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(Vec::::new())); - generic_test_op!( - a, - DataType::Boolean, - Max, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None as Option])); - generic_test_op!( - a, - DataType::Boolean, - Max, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = - Arc::new(BooleanArray::from(vec![None, Some(true), Some(false)])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - Ok(()) - } - - #[test] - fn min_bool() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, false])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, true, false])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - let 
a: ArrayRef = Arc::new(BooleanArray::from(vec![true, false, true])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(Vec::::new())); - generic_test_op!( - a, - DataType::Boolean, - Min, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None as Option])); - generic_test_op!( - a, - DataType::Boolean, - Min, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = - Arc::new(BooleanArray::from(vec![None, Some(true), Some(false)])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 3ce641c5aa46..7faf2caae01c 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -70,7 +70,6 @@ pub fn add_offset_to_expr( #[cfg(test)] mod tests { - use super::*; use crate::expressions::col; use crate::PhysicalSortExpr; @@ -147,7 +146,7 @@ mod tests { let col_f = &col("f", &test_schema)?; let col_g = &col("g", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - eq_properties.add_equal_conditions(col_a, col_c); + eq_properties.add_equal_conditions(col_a, col_c)?; let option_asc = SortOptions { descending: false, @@ -204,7 +203,7 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f); + eq_properties.add_equal_conditions(col_a, col_f)?; // Column e has constant value. eq_properties = eq_properties.add_constants([col_e.clone()]); @@ -338,11 +337,11 @@ mod tests { let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 2); @@ -351,7 +350,7 @@ mod tests { // b and c are aliases. Exising equivalence class should expand, // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 3); @@ -360,12 +359,12 @@ mod tests { assert!(eq_groups.contains(&col_c_expr)); // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; assert_eq!(eq_properties.eq_group().len(), 2); // This equality bridges distinct equality sets. // Hence equivalent class count should decrease from 2 to 1. 
- eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 5); diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index ed4600f2d95e..7857d9df726e 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -223,26 +223,26 @@ fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> mod tests { use std::sync::Arc; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SortOptions; - use itertools::Itertools; - - use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{Operator, ScalarUDF}; - use crate::equivalence::tests::{ convert_to_orderings, convert_to_sort_exprs, create_random_schema, - create_test_params, generate_table_for_eq_properties, is_table_same_after_sort, + create_test_params, create_test_schema, generate_table_for_eq_properties, + is_table_same_after_sort, }; - use crate::equivalence::{tests::create_test_schema, EquivalenceProperties}; use crate::equivalence::{ - EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, + EquivalenceClass, EquivalenceGroup, EquivalenceProperties, + OrderingEquivalenceClass, }; - use crate::expressions::Column; - use crate::expressions::{col, BinaryExpr}; + use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SortOptions; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{Operator, ScalarUDF}; + + use itertools::Itertools; + #[test] fn test_ordering_satisfy() -> Result<()> { let input_schema = Arc::new(Schema::new(vec![ @@ -883,7 +883,7 @@ mod tests { }; // a=c (e.g they are aliases). let mut eq_properties = EquivalenceProperties::new(test_schema); - eq_properties.add_equal_conditions(col_a, col_c); + eq_properties.add_equal_conditions(col_a, col_c)?; let orderings = vec![ vec![(col_a, options)], diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 260610f23dc6..b5ac149d8b71 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -17,14 +17,13 @@ use std::sync::Arc; -use arrow::datatypes::SchemaRef; +use crate::expressions::Column; +use crate::PhysicalExpr; +use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, Result}; -use crate::expressions::Column; -use crate::PhysicalExpr; - /// Stores the mapping between source expressions and target expressions for a /// projection. 
#[derive(Debug, Clone)] @@ -114,14 +113,7 @@ impl ProjectionMapping { #[cfg(test)] mod tests { - - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{SortOptions, TimeUnit}; - use itertools::Itertools; - - use datafusion_common::DFSchema; - use datafusion_expr::{Operator, ScalarUDF}; - + use super::*; use crate::equivalence::tests::{ apply_projection, convert_to_orderings, convert_to_orderings_owned, create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, @@ -133,7 +125,12 @@ mod tests { use crate::utils::tests::TestScalarUDF; use crate::PhysicalSortExpr; - use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{SortOptions, TimeUnit}; + use datafusion_common::DFSchema; + use datafusion_expr::{Operator, ScalarUDF}; + + use itertools::Itertools; #[test] fn project_orderings() -> Result<()> { @@ -941,7 +938,7 @@ mod tests { for (orderings, equal_columns, expected) in test_cases { let mut eq_properties = EquivalenceProperties::new(schema.clone()); for (lhs, rhs) in equal_columns { - eq_properties.add_equal_conditions(lhs, rhs); + eq_properties.add_equal_conditions(lhs, rhs)?; } let orderings = convert_to_orderings(&orderings); diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index c654208208df..016c4c4ae107 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -18,25 +18,27 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow_schema::{SchemaRef, SortOptions}; -use indexmap::{IndexMap, IndexSet}; -use itertools::Itertools; - -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{JoinSide, JoinType, Result}; - +use super::ordering::collapse_lex_ordering; use crate::equivalence::{ collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; use crate::expressions::{CastExpr, Literal}; -use crate::sort_properties::{ExprOrdering, SortProperties}; use crate::{ physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use super::ordering::collapse_lex_ordering; +use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{JoinSide, JoinType, Result}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_physical_expr_common::utils::ExprPropertiesNode; + +use indexmap::{IndexMap, IndexSet}; +use itertools::Itertools; /// A `EquivalenceProperties` object stores useful information related to a schema. 
/// Currently, it keeps track of: @@ -197,7 +199,7 @@ impl EquivalenceProperties { &mut self, left: &Arc, right: &Arc, - ) { + ) -> Result<()> { // Discover new constants in light of new the equality: if self.is_expr_constant(left) { // Left expression is constant, add right as constant @@ -216,27 +218,34 @@ impl EquivalenceProperties { let mut new_orderings = vec![]; for ordering in self.normalized_oeq_class().iter() { let expressions = if left.eq(&ordering[0].expr) { - // left expression is leading ordering + // Left expression is leading ordering Some((ordering[0].options, right)) } else if right.eq(&ordering[0].expr) { - // right expression is leading ordering + // Right expression is leading ordering Some((ordering[0].options, left)) } else { None }; if let Some((leading_ordering, other_expr)) = expressions { - // Only handle expressions with exactly one child - // TODO: it should be possible to handle expressions orderings f(a, b, c), a, b, c - // if f is monotonic in all arguments - // First Expression after leading ordering + // Currently, we only handle expressions with a single child. + // TODO: It should be possible to handle expressions orderings like + // f(a, b, c), a, b, c if f is monotonic in all arguments. + // First expression after leading ordering if let Some(next_expr) = ordering.get(1) { let children = other_expr.children(); if children.len() == 1 && children[0].eq(&next_expr.expr) && SortProperties::Ordered(leading_ordering) - == other_expr.get_ordering(&[SortProperties::Ordered( - next_expr.options, - )]) + == other_expr + .get_properties(&[ExprProperties { + sort_properties: SortProperties::Ordered( + leading_ordering, + ), + range: Interval::make_unbounded( + &other_expr.data_type(&self.schema)?, + )?, + }])? + .sort_properties { // Assume existing ordering is [a ASC, b ASC] // When equality a = f(b) is given, If we know that given ordering `[b ASC]`, ordering `[f(b) ASC]` is valid, @@ -254,6 +263,7 @@ impl EquivalenceProperties { // Add equal expressions to the state self.eq_group.add_equal_conditions(left, right); + Ok(()) } /// Track/register physical expressions with constant values. @@ -378,11 +388,15 @@ impl EquivalenceProperties { /// /// Returns `true` if the specified ordering is satisfied, `false` otherwise. fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { - let expr_ordering = self.get_expr_ordering(req.expr.clone()); - let ExprOrdering { expr, data, .. } = expr_ordering; - match data { + let ExprProperties { + sort_properties, .. + } = self.get_expr_properties(req.expr.clone()); + match sort_properties { SortProperties::Ordered(options) => { - let sort_expr = PhysicalSortExpr { expr, options }; + let sort_expr = PhysicalSortExpr { + expr: req.expr.clone(), + options, + }; sort_expr.satisfy(req, self.schema()) } // Singleton expressions satisfies any ordering. @@ -698,8 +712,9 @@ impl EquivalenceProperties { referred_dependencies(&dependency_map, source) .into_iter() .filter_map(|relevant_deps| { - if let SortProperties::Ordered(options) = - get_expr_ordering(source, &relevant_deps) + if let Ok(SortProperties::Ordered(options)) = + get_expr_properties(source, &relevant_deps, &self.schema) + .map(|prop| prop.sort_properties) { Some((options, relevant_deps)) } else { @@ -837,16 +852,27 @@ impl EquivalenceProperties { let ordered_exprs = search_indices .iter() .flat_map(|&idx| { - let ExprOrdering { expr, data, .. 
} = - eq_properties.get_expr_ordering(exprs[idx].clone()); - match data { - SortProperties::Ordered(options) => { - Some((PhysicalSortExpr { expr, options }, idx)) - } + let ExprProperties { + sort_properties, .. + } = eq_properties.get_expr_properties(exprs[idx].clone()); + match sort_properties { + SortProperties::Ordered(options) => Some(( + PhysicalSortExpr { + expr: exprs[idx].clone(), + options, + }, + idx, + )), SortProperties::Singleton => { // Assign default ordering to constant expressions let options = SortOptions::default(); - Some((PhysicalSortExpr { expr, options }, idx)) + Some(( + PhysicalSortExpr { + expr: exprs[idx].clone(), + options, + }, + idx, + )) } SortProperties::Unordered => None, } @@ -895,32 +921,33 @@ impl EquivalenceProperties { is_constant_recurse(&normalized_constants, &normalized_expr) } - /// Retrieves the ordering information for a given physical expression. + /// Retrieves the properties for a given physical expression. /// - /// This function constructs an `ExprOrdering` object for the provided + /// This function constructs an [`ExprProperties`] object for the given /// expression, which encapsulates information about the expression's - /// ordering, including its [`SortProperties`]. + /// properties, including its [`SortProperties`] and [`Interval`]. /// - /// # Arguments + /// # Parameters /// /// - `expr`: An `Arc` representing the physical expression /// for which ordering information is sought. /// /// # Returns /// - /// Returns an `ExprOrdering` object containing the ordering information for - /// the given expression. - pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { - ExprOrdering::new_default(expr.clone()) - .transform_up(|expr| Ok(update_ordering(expr, self))) + /// Returns an [`ExprProperties`] object containing the ordering and range + /// information for the given expression. + pub fn get_expr_properties(&self, expr: Arc) -> ExprProperties { + ExprPropertiesNode::new_unknown(expr) + .transform_up(|expr| update_properties(expr, self)) .data() - // Guaranteed to always return `Ok`. - .unwrap() + .map(|node| node.data) + .unwrap_or(ExprProperties::new_unknown()) } } -/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. -/// The node can either be a leaf node, or an intermediate node: +/// Calculates the properties of a given [`ExprPropertiesNode`]. +/// +/// Order information can be retrieved as: /// - If it is a leaf node, we directly find the order of the node by looking /// at the given sort expression and equivalence properties if it is a `Column` /// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark @@ -931,30 +958,41 @@ impl EquivalenceProperties { /// node directly matches with the sort expression. If there is a match, the /// sort expression emerges at that node immediately, discarding the recursive /// result coming from its children. -fn update_ordering( - mut node: ExprOrdering, +/// +/// Range information is calculated as: +/// - If it is a `Literal` node, we set the range as a point value. If it is a +/// `Column` node, we set the datatype of the range, but cannot give an interval +/// for the range, yet. +/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` +/// and operator has its own rules on how to propagate the children range. 
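For intuition, the per-node state propagated below pairs an ordering with a value range. A minimal, hypothetical sketch of such a state for an ascending `Int32` column (using only the `ExprProperties`, `SortProperties` and `Interval` APIs that appear in this diff):

```rust
use arrow_schema::{DataType, SortOptions};
use datafusion_common::Result;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};

fn ascending_int32_properties() -> Result<ExprProperties> {
    // The order is known (ascending, nulls last); the value range starts
    // out unbounded over the column's datatype and can be narrowed later.
    Ok(ExprProperties {
        sort_properties: SortProperties::Ordered(SortOptions {
            descending: false,
            nulls_first: false,
        }),
        range: Interval::make_unbounded(&DataType::Int32)?,
    })
}
```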
+fn update_properties( + mut node: ExprPropertiesNode, eq_properties: &EquivalenceProperties, -) -> Transformed { +) -> Result> { + // First, try to gather the information from the children: + if !node.expr.children().is_empty() { + // We have an intermediate (non-leaf) node, account for its children: + let children_props = node.children.iter().map(|c| c.data.clone()).collect_vec(); + node.data = node.expr.get_properties(&children_props)?; + } else if node.expr.as_any().is::() { + // We have a Literal, which is one of the two possible leaf node types: + node.data = node.expr.get_properties(&[])?; + } else if node.expr.as_any().is::() { + // We have a Column, which is the other possible leaf node type: + node.data.range = + Interval::make_unbounded(&node.expr.data_type(eq_properties.schema())?)? + } + // Now, check what we know about orderings: let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); if eq_properties.is_expr_constant(&normalized_expr) { - node.data = SortProperties::Singleton; + node.data.sort_properties = SortProperties::Singleton; } else if let Some(options) = eq_properties .normalized_oeq_class() .get_options(&normalized_expr) { - node.data = SortProperties::Ordered(options); - } else if !node.expr.children().is_empty() { - // We have an intermediate (non-leaf) node, account for its children: - let children_orderings = node.children.iter().map(|c| c.data).collect_vec(); - node.data = node.expr.get_ordering(&children_orderings); - } else if node.expr.as_any().is::() { - // We have a Literal, which is the other possible leaf node type: - node.data = node.expr.get_ordering(&[]); - } else { - return Transformed::no(node); + node.data.sort_properties = SortProperties::Ordered(options); } - Transformed::yes(node) + Ok(Transformed::yes(node)) } /// This function determines whether the provided expression is constant @@ -1124,8 +1162,9 @@ fn generate_dependency_orderings( .collect() } -/// This function examines the given expression and the sort expressions it -/// refers to determine the ordering properties of the expression. +/// This function examines the given expression and its properties to determine +/// the ordering properties of the expression. The range knowledge is not utilized +/// yet in the scope of this function. /// /// # Parameters /// /// - `expr`: An `Arc` representing the physical expression /// which ordering properties need to be determined. /// - `dependencies`: A reference to `Dependencies`, containing sort expressions /// referred to by `expr`. +/// - `schema`: A reference to the schema to which the `expr` columns refer. /// /// # Returns /// /// A `SortProperties` indicating the ordering information of the given expression. -fn get_expr_ordering( +fn get_expr_properties( expr: &Arc, dependencies: &Dependencies, -) -> SortProperties { + schema: &SchemaRef, +) -> Result { if let Some(column_order) = dependencies.iter().find(|&order| expr.eq(&order.expr)) { // If exact match is found, return its ordering.
- SortProperties::Ordered(column_order.options) + Ok(ExprProperties { + sort_properties: SortProperties::Ordered(column_order.options), + range: Interval::make_unbounded(&expr.data_type(schema)?)?, + }) + } else if expr.as_any().downcast_ref::().is_some() { + Ok(ExprProperties { + sort_properties: SortProperties::Unordered, + range: Interval::make_unbounded(&expr.data_type(schema)?)?, + }) + } else if let Some(literal) = expr.as_any().downcast_ref::() { + Ok(ExprProperties { + sort_properties: SortProperties::Singleton, + range: Interval::try_new(literal.value().clone(), literal.value().clone())?, + }) } else { // Find orderings of its children let child_states = expr .children() .iter() - .map(|child| get_expr_ordering(child, dependencies)) - .collect::>(); + .map(|child| get_expr_properties(child, dependencies, schema)) + .collect::>>()?; // Calculate expression ordering using ordering of its children. - expr.get_ordering(&child_states) + expr.get_properties(&child_states) } } @@ -1351,12 +1405,7 @@ impl Hash for ExprWrapper { mod tests { use std::ops::Not; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{Fields, TimeUnit}; - - use datafusion_common::DFSchema; - use datafusion_expr::{Operator, ScalarUDF}; - + use super::*; use crate::equivalence::add_offset_to_expr; use crate::equivalence::tests::{ convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, @@ -1366,7 +1415,10 @@ mod tests { use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; - use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{Fields, TimeUnit}; + use datafusion_common::DFSchema; + use datafusion_expr::{Operator, ScalarUDF}; #[test] fn project_equivalence_properties_test() -> Result<()> { @@ -1577,8 +1629,8 @@ mod tests { let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); // a=x and d=w - join_eq_properties.add_equal_conditions(col_a, col_x); - join_eq_properties.add_equal_conditions(col_d, col_w); + join_eq_properties.add_equal_conditions(col_a, col_x)?; + join_eq_properties.add_equal_conditions(col_d, col_w)?; updated_right_ordering_equivalence_class( &mut right_oeq_class, @@ -1615,7 +1667,7 @@ mod tests { let col_c_expr = col("c", &schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); + eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr)?; let others = vec![ vec![PhysicalSortExpr { expr: col_b_expr.clone(), @@ -1760,7 +1812,7 @@ mod tests { } #[test] - fn test_update_ordering() -> Result<()> { + fn test_update_properties() -> Result<()> { let schema = Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), @@ -1778,7 +1830,7 @@ mod tests { nulls_first: false, }; // b=a (e.g they are aliases) - eq_properties.add_equal_conditions(col_b, col_a); + eq_properties.add_equal_conditions(col_b, col_a)?; // [b ASC], [d ASC] eq_properties.add_new_orderings(vec![ vec![PhysicalSortExpr { @@ -1821,12 +1873,12 @@ mod tests { .iter() .flat_map(|ordering| ordering.first().cloned()) .collect::>(); - let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); + let expr_props = eq_properties.get_expr_properties(expr.clone()); let err_msg = format!( "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", - expr, expected, expr_ordering.data + expr, expected, expr_props.sort_properties ); - assert_eq!(expr_ordering.data, expected, 
"{}", err_msg); + assert_eq!(expr_props.sort_properties, expected, "{}", err_msg); } Ok(()) @@ -2266,6 +2318,7 @@ mod tests { Ok(()) } + #[test] fn test_eliminate_redundant_monotonic_sorts() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -2334,7 +2387,7 @@ mod tests { for case in cases { let mut properties = base_properties.clone().add_constants(case.constants); for [left, right] in &case.equal_conditions { - properties.add_equal_conditions(left, right) + properties.add_equal_conditions(left, right)? } let sort = case diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 76154dca0338..08f7523f92f0 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -23,21 +23,21 @@ use std::{any::Any, sync::Arc}; use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::physical_expr::down_cast_any_ref; -use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; use arrow::compute::kernels::cmp::*; -use arrow::compute::kernels::comparison::regexp_is_match_utf8; -use arrow::compute::kernels::comparison::regexp_is_match_utf8_scalar; +use arrow::compute::kernels::comparison::{ + regexp_is_match_utf8, regexp_is_match_utf8_scalar, +}; use arrow::compute::kernels::concat_elements::concat_elements_utf8; use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; - use datafusion_common::cast::as_boolean_array; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; +use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; @@ -442,17 +442,45 @@ impl PhysicalExpr for BinaryExpr { self.hash(&mut s); } - /// For each operator, [`BinaryExpr`] has distinct ordering rules. - /// TODO: There may be rules specific to some data types (such as division and multiplication on unsigned integers) - fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { - let (left_child, right_child) = (&children[0], &children[1]); + /// For each operator, [`BinaryExpr`] has distinct rules. + /// TODO: There may be rules specific to some data types and expression ranges. 
+ fn get_properties(&self, children: &[ExprProperties]) -> Result { + let (l_order, l_range) = (children[0].sort_properties, &children[0].range); + let (r_order, r_range) = (children[1].sort_properties, &children[1].range); match self.op() { - Operator::Plus => left_child.add(right_child), - Operator::Minus => left_child.sub(right_child), - Operator::Gt | Operator::GtEq => left_child.gt_or_gteq(right_child), - Operator::Lt | Operator::LtEq => right_child.gt_or_gteq(left_child), - Operator::And | Operator::Or => left_child.and_or(right_child), - _ => SortProperties::Unordered, + Operator::Plus => Ok(ExprProperties { + sort_properties: l_order.add(&r_order), + range: l_range.add(r_range)?, + }), + Operator::Minus => Ok(ExprProperties { + sort_properties: l_order.sub(&r_order), + range: l_range.sub(r_range)?, + }), + Operator::Gt => Ok(ExprProperties { + sort_properties: l_order.gt_or_gteq(&r_order), + range: l_range.gt(r_range)?, + }), + Operator::GtEq => Ok(ExprProperties { + sort_properties: l_order.gt_or_gteq(&r_order), + range: l_range.gt_eq(r_range)?, + }), + Operator::Lt => Ok(ExprProperties { + sort_properties: r_order.gt_or_gteq(&l_order), + range: l_range.lt(r_range)?, + }), + Operator::LtEq => Ok(ExprProperties { + sort_properties: r_order.gt_or_gteq(&l_order), + range: l_range.lt_eq(r_range)?, + }), + Operator::And => Ok(ExprProperties { + sort_properties: r_order.and_or(&l_order), + range: l_range.and(r_range)?, + }), + Operator::Or => Ok(ExprProperties { + sort_properties: r_order.and_or(&l_order), + range: l_range.or(r_range)?, + }), + _ => Ok(ExprProperties::new_unknown()), } } } @@ -623,6 +651,7 @@ pub fn binary( mod tests { use super::*; use crate::expressions::{col, lit, try_cast, Literal}; + use datafusion_common::plan_datafusion_err; use datafusion_expr::type_coercion::binary::get_input_types; diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index a3b32461e581..79a44ac30cfc 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -15,21 +15,21 @@ // specific language governing permissions and limitations // under the License. -use crate::physical_expr::down_cast_any_ref; -use crate::sort_properties::SortProperties; -use crate::PhysicalExpr; use std::any::Any; use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use DataType::*; + +use crate::physical_expr::down_cast_any_ref; +use crate::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, DataType::*, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::ColumnarValue; const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { @@ -163,9 +163,21 @@ impl PhysicalExpr for CastExpr { self.cast_options.hash(&mut s); } - /// A [`CastExpr`] preserves the ordering of its child. - fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { - children[0] + /// A [`CastExpr`] preserves the ordering of its child if the cast is done + /// under the same datatype family. 
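+    /// For example, casting an ascending `Int32` column to `Int64` keeps it
+    /// ascending, but casting it to `Utf8` does not (numerically `9 < 10`, yet
+    /// `"10" < "9"` lexicographically).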
+ fn get_properties(&self, children: &[ExprProperties]) -> Result { + let source_datatype = children[0].range.data_type(); + let target_type = &self.cast_type; + + let unbounded = Interval::make_unbounded(target_type)?; + if source_datatype.is_numeric() && target_type.is_numeric() + || source_datatype.is_temporal() && target_type.is_temporal() + || source_datatype.eq(target_type) + { + Ok(children[0].clone().with_range(unbounded)) + } else { + Ok(ExprProperties::new_unknown().with_range(unbounded)) + } } } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 35ea80ea574d..371028959ab8 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -22,7 +22,6 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::down_cast_any_ref; -use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::{ @@ -30,6 +29,8 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ColumnarValue, Expr}; /// Represents a literal value @@ -90,8 +91,11 @@ impl PhysicalExpr for Literal { self.hash(&mut s); } - fn get_ordering(&self, _children: &[SortProperties]) -> SortProperties { - SortProperties::Singleton + fn get_properties(&self, _children: &[ExprProperties]) -> Result { + Ok(ExprProperties { + sort_properties: SortProperties::Singleton, + range: Interval::try_new(self.value().clone(), self.value().clone())?, + }) } } @@ -115,6 +119,7 @@ pub fn lit(value: T) -> Arc { #[cfg(test)] mod tests { use super::*; + use arrow::array::Int32Array; use arrow::datatypes::*; use datafusion_common::cast::as_int32_array; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c16b609e2375..980297b8b433 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -100,9 +100,7 @@ pub(crate) mod tests { use crate::AggregateExpr; use arrow::record_batch::RecordBatch; - use arrow_array::ArrayRef; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::EmitTo; /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the /// result. 
@@ -250,22 +248,4 @@ pub(crate) mod tests { accum.update_batch(&values)?; accum.evaluate() } - - pub fn aggregate_new( - batch: &RecordBatch, - agg: Arc, - ) -> Result { - let mut accum = agg.create_groups_accumulator()?; - let expr = agg.expressions(); - let values = expr - .iter() - .map(|e| { - e.evaluate(batch) - .and_then(|v| v.into_array(batch.num_rows())) - }) - .collect::>>()?; - let indices = vec![0; batch.num_rows()]; - accum.update_batch(&values, &indices, None, 1)?; - accum.evaluate(EmitTo::All) - } } diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index f6d4620c427f..62f865bd9b32 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -22,7 +22,6 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::down_cast_any_ref; -use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::{ @@ -32,6 +31,7 @@ use arrow::{ }; use datafusion_common::{plan_err, Result}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::{ type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, ColumnarValue, @@ -134,8 +134,11 @@ impl PhysicalExpr for NegativeExpr { } /// The ordering of a [`NegativeExpr`] is simply the reverse of its child. - fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { - -children[0] + fn get_properties(&self, children: &[ExprProperties]) -> Result { + Ok(ExprProperties { + sort_properties: -children[0].sort_properties, + range: children[0].range.clone().arithmetic_negate()?, + }) } } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 21cf6d348cd5..9c7d6d09349d 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -33,14 +33,12 @@ use std::sync::Arc; -use arrow::array::ArrayRef; -use arrow_array::Array; - -pub use crate::scalar_function::create_physical_expr; +use arrow::array::{Array, ArrayRef}; use datafusion_common::{Result, ScalarValue}; -pub use datafusion_expr::FuncMonotonicity; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +pub use crate::scalar_function::create_physical_expr; + #[derive(Debug, Clone, Copy)] pub enum Hint { /// Indicates the argument needs to be padded if it is scalar diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index aef5aa7c00e7..1bdf082b2eaf 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -61,13 +61,6 @@ pub use scalar_function::ScalarFunctionExpr; pub use datafusion_physical_expr_common::utils::reverse_order_bys; pub use utils::split_conjunction; -// For backwards compatibility -pub mod sort_properties { - pub use datafusion_physical_expr_common::sort_properties::{ - ExprOrdering, SortProperties, - }; -} - // For backwards compatibility pub mod tree_node { pub use datafusion_physical_expr_common::tree_node::ExprContext; diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 1244a9b4db38..daa110071096 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -32,19 +32,18 @@ use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; -use std::ops::Neg; use std::sync::Arc; +use 
crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; +use crate::PhysicalExpr; + use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; - use datafusion_common::{internal_err, DFSchema, Result}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; -use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF}; - -use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; -use crate::sort_properties::SortProperties; -use crate::PhysicalExpr; +use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarUDF}; /// Physical expression of a scalar function pub struct ScalarFunctionExpr { @@ -52,11 +51,6 @@ pub struct ScalarFunctionExpr { name: String, args: Vec>, return_type: DataType, - // Keeps monotonicity information of the function. - // FuncMonotonicity vector is one to one mapped to `args`, - // and it specifies the effect of an increase or decrease in - // the corresponding `arg` to the function value. - monotonicity: Option, } impl Debug for ScalarFunctionExpr { @@ -66,7 +60,6 @@ impl Debug for ScalarFunctionExpr { .field("name", &self.name) .field("args", &self.args) .field("return_type", &self.return_type) - .field("monotonicity", &self.monotonicity) .finish() } } @@ -78,14 +71,12 @@ impl ScalarFunctionExpr { fun: Arc, args: Vec>, return_type: DataType, - monotonicity: Option, ) -> Self { Self { fun, name: name.to_owned(), args, return_type, - monotonicity, } } @@ -108,11 +99,6 @@ impl ScalarFunctionExpr { pub fn return_type(&self) -> &DataType { &self.return_type } - - /// Monotonicity information of the function - pub fn monotonicity(&self) -> &Option { - &self.monotonicity - } } impl fmt::Display for ScalarFunctionExpr { @@ -170,10 +156,21 @@ impl PhysicalExpr for ScalarFunctionExpr { self.fun.clone(), children, self.return_type().clone(), - self.monotonicity.clone(), ))) } + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + self.fun.evaluate_bounds(children) + } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + self.fun.propagate_constraints(interval, children) + } + fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.name.hash(&mut s); @@ -182,11 +179,18 @@ impl PhysicalExpr for ScalarFunctionExpr { // Add `self.fun` when hash is available } - fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { - self.monotonicity - .as_ref() - .map(|monotonicity| out_ordering(monotonicity, children)) - .unwrap_or(SortProperties::Unordered) + fn get_properties(&self, children: &[ExprProperties]) -> Result { + let sort_properties = self.fun.monotonicity(children)?; + let children_range = children + .iter() + .map(|props| &props.range) + .collect::>(); + let range = self.fun().evaluate_bounds(&children_range)?; + + Ok(ExprProperties { + sort_properties, + range, + }) } } @@ -231,63 +235,5 @@ pub fn create_physical_expr( Arc::new(fun.clone()), input_phy_exprs.to_vec(), return_type, - fun.monotonicity()?, ))) } - -/// Determines a [ScalarFunctionExpr]'s monotonicity for the given arguments -/// and the function's behavior depending on its arguments. 
-/// -/// [ScalarFunctionExpr]: crate::scalar_function::ScalarFunctionExpr -pub fn out_ordering( - func: &FuncMonotonicity, - arg_orderings: &[SortProperties], -) -> SortProperties { - func.iter().zip(arg_orderings).fold( - SortProperties::Singleton, - |prev_sort, (item, arg)| { - let current_sort = func_order_in_one_dimension(item, arg); - - match (prev_sort, current_sort) { - (_, SortProperties::Unordered) => SortProperties::Unordered, - (SortProperties::Singleton, SortProperties::Ordered(_)) => current_sort, - (SortProperties::Ordered(prev), SortProperties::Ordered(current)) - if prev.descending != current.descending => - { - SortProperties::Unordered - } - _ => prev_sort, - } - }, - ) -} - -/// This function decides the monotonicity property of a [ScalarFunctionExpr] for a single argument (i.e. across a single dimension), given that argument's sort properties. -/// -/// [ScalarFunctionExpr]: crate::scalar_function::ScalarFunctionExpr -fn func_order_in_one_dimension( - func_monotonicity: &Option, - arg: &SortProperties, -) -> SortProperties { - if *arg == SortProperties::Singleton { - SortProperties::Singleton - } else { - match func_monotonicity { - None => SortProperties::Unordered, - Some(false) => { - if let SortProperties::Ordered(_) = arg { - arg.neg() - } else { - SortProperties::Unordered - } - } - Some(true) => { - if let SortProperties::Ordered(_) = arg { - *arg - } else { - SortProperties::Unordered - } - } - } - } -} diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 76cee3a1a786..6b964546cb74 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -255,19 +255,18 @@ pub fn merge_vectors( #[cfg(test)] pub(crate) mod tests { - use arrow_array::{ArrayRef, Float32Array, Float64Array}; use std::any::Any; use std::fmt::{Display, Formatter}; use super::*; use crate::expressions::{binary, cast, col, in_list, lit, Literal}; + use arrow_array::{ArrayRef, Float32Array, Float64Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{exec_err, DataFusionError, ScalarValue}; + use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - }; use petgraph::visit::Bfs; #[derive(Debug, Clone)] @@ -309,8 +308,8 @@ pub(crate) mod tests { } } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn monotonicity(&self, input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) } fn invoke(&self, args: &[ColumnarValue]) -> Result { diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 95376e7e69cd..21608db40d56 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -2080,7 +2080,7 @@ mod tests { let col_c = &col("c", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(test_schema); // Columns a and b are equal. 
- eq_properties.add_equal_conditions(col_a, col_b); + eq_properties.add_equal_conditions(col_a, col_b)?; // Aggregate requirements are // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively let order_by_exprs = vec![ diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 9e2216ae0a63..c61e9a05bfa6 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -153,16 +153,23 @@ pub fn compute_record_batch_statistics( }) .sum(); - let mut column_statistics = vec![ColumnStatistics::new_unknown(); projection.len()]; + let mut null_counts = vec![0; projection.len()]; for partition in batches.iter() { for batch in partition { for (stat_index, col_index) in projection.iter().enumerate() { - column_statistics[stat_index].null_count = - Precision::Exact(batch.column(*col_index).null_count()); + null_counts[stat_index] += batch.column(*col_index).null_count(); } } } + let column_statistics = null_counts + .into_iter() + .map(|null_count| { + let mut s = ColumnStatistics::new_unknown(); + s.null_count = Precision::Exact(null_count); + s + }) + .collect(); Statistics { num_rows: Precision::Exact(nb_rows), @@ -687,4 +694,35 @@ mod tests { assert_eq!(actual, expected); Ok(()) } + + #[test] + fn test_compute_record_batch_statistics_null() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("u64", DataType::UInt64, true)])); + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt64Array::from(vec![Some(1), None, None]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt64Array::from(vec![Some(1), Some(2), None]))], + )?; + let byte_size = batch1.get_array_memory_size() + batch2.get_array_memory_size(); + let actual = + compute_record_batch_statistics(&[vec![batch1], vec![batch2]], &schema, None); + + let expected = Statistics { + num_rows: Precision::Exact(6), + total_byte_size: Precision::Exact(byte_size), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Exact(3), + }], + }; + + assert_eq!(actual, expected); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index bf1ab8b73126..6729e3b9e603 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -192,7 +192,7 @@ impl FilterExec { let mut eq_properties = input.equivalence_properties().clone(); let (equal_pairs, _) = collect_columns_from_predicate(predicate); for (lhs, rhs) in equal_pairs { - eq_properties.add_equal_conditions(lhs, rhs) + eq_properties.add_equal_conditions(lhs, rhs)? } // Add the columns that have only one viable value (singleton) after // filtering to constants. 
@@ -433,13 +433,12 @@ pub type EqualAndNonEqual<'a> = #[cfg(test)] mod tests { - use super::*; + use crate::empty::EmptyExec; use crate::expressions::*; use crate::test; use crate::test::exec::StatisticsExec; - use crate::empty::EmptyExec; use arrow::datatypes::{Field, Schema}; use arrow_schema::{UnionFields, UnionMode}; use datafusion_common::ScalarValue; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 4c928a3d2d8d..d4cf6864d7e4 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1323,7 +1323,9 @@ impl SMJStream { // If join filter exists, `self.output_size` is not accurate as we don't know the exact // number of rows in the output record batch. If streamed row joined with buffered rows, // once join filter is applied, the number of output rows may be more than 1. - if record_batch.num_rows() > self.output_size { + // If `record_batch` is empty, we should reset `self.output_size` to 0. It could be happened + // when the join filter is applied and all rows are filtered out. + if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size { self.output_size = 0; } else { self.output_size -= record_batch.num_rows(); diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index ff60329ce179..42c630741cc9 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -42,6 +42,7 @@ use datafusion_physical_expr::{ window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr}, AggregateExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, }; +use itertools::Itertools; mod bounded_window_agg_exec; mod window_agg_exec; @@ -52,6 +53,31 @@ pub use datafusion_physical_expr::window::{ }; pub use window_agg_exec::WindowAggExec; +/// Build field from window function and add it into schema +pub fn schema_add_window_field( + args: &[Arc], + schema: &Schema, + window_fn: &WindowFunctionDefinition, + fn_name: &str, +) -> Result> { + let data_types = args + .iter() + .map(|e| e.clone().as_ref().data_type(schema)) + .collect::>>()?; + let window_expr_return_type = window_fn.return_type(&data_types)?; + let mut window_fields = schema + .fields() + .iter() + .map(|f| f.as_ref().clone()) + .collect_vec(); + window_fields.extend_from_slice(&[Field::new( + fn_name, + window_expr_return_type, + false, + )]); + Ok(Arc::new(Schema::new(window_fields))) +} + /// Create a physical expression for window function #[allow(clippy::too_many_arguments)] pub fn create_window_expr( @@ -103,6 +129,7 @@ pub fn create_window_expr( input_schema, name, ignore_nulls, + false, )?; window_expr_from_aggregate_expr( partition_by, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index c907e991fb86..b7bc60a0486c 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -40,7 +40,7 @@ use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, }; -use datafusion::physical_plan::windows::create_window_expr; +use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion::physical_plan::{ ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; @@ -155,14 +155,18 @@ pub 
fn parse_physical_window_expr( ) })?; + let fun: WindowFunctionDefinition = convert_required!(proto.window_function)?; + let name = proto.name.clone(); + let extended_schema = + schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?; create_window_expr( - &convert_required!(proto.window_function)?, - proto.name.clone(), + &fun, + name, &window_node_expr, &partition_by, &order_by, Arc::new(window_frame), - input_schema, + &extended_schema, false, ) } @@ -350,7 +354,6 @@ pub fn parse_physical_expr( scalar_fun_def, args, convert_required!(e.return_type)?, - None, )) } ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new( diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 1c5ba861d297..4de0b7c06d45 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -525,7 +525,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let sort_exprs = &[]; let ordering_req = &[]; let ignore_nulls = false; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ec215937dca8..b5b0b4c2247a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -622,8 +622,8 @@ async fn roundtrip_expr_api() -> Result<()> { ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), first_value(vec![lit(1)], false, None, None, None), - covar_samp(lit(1.5), lit(2.2), false, None, None, None), - covar_pop(lit(1.5), lit(2.2), true, None, None, None), + covar_samp(lit(1.5), lit(2.2)), + covar_pop(lit(1.5), lit(2.2)), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index c2018352c7cf..79abecf556da 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -253,8 +253,7 @@ fn roundtrip_nested_loop_join() -> Result<()> { fn roundtrip_window() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); - let field_c = Field::new("FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); let window_frame = WindowFrame::new_bounds( datafusion_expr::WindowFrameUnits::Range, @@ -426,6 +425,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { &schema, "example_agg", false, + false, )?]; roundtrip_test_with_context( @@ -624,7 +624,6 @@ fn roundtrip_scalar_udf() -> Result<()> { fun_def, vec![col("a", &schema)?], DataType::Int64, - None, ); let project = @@ -752,7 +751,6 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { Arc::new(udf.clone()), vec![col("text", &schema)?], DataType::Int64, - None, )); let filter = Arc::new(FilterExec::try_new( diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 
3adf2960784d..dc0ddd4714dd 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -229,12 +229,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; - // TODO: Support filter and distinct for UDAFs + let filter: Option> = filter + .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) + .transpose()? + .map(Box::new); return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( fm, args, - false, - None, + distinct, + filter, order_by, null_treatment, ))); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 804fa6d306b4..a50bb7a69823 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -26,7 +26,7 @@ use datafusion_common::{ }; use datafusion_expr::{ expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction}, - Between, BinaryExpr, Case, Cast, Expr, Like, Operator, + Between, BinaryExpr, Case, Cast, Expr, Like, Operator, TryCast, }; use sqlparser::ast::{ self, Expr as AstExpr, Function, FunctionArg, Ident, UnaryOperator, @@ -356,6 +356,9 @@ impl Unparser<'_> { asc: _, nulls_first: _, }) => plan_err!("Sort expression should be handled by expr_to_unparsed"), + Expr::IsNull(expr) => { + Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?))) + } Expr::IsNotNull(expr) => { Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?))) } @@ -368,6 +371,9 @@ impl Unparser<'_> { Expr::IsFalse(expr) => { Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?))) } + Expr::IsNotFalse(expr) => { + Ok(ast::Expr::IsNotFalse(Box::new(self.expr_to_sql(expr)?))) + } Expr::IsUnknown(expr) => { Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?))) } @@ -388,27 +394,44 @@ impl Unparser<'_> { expr: Box::new(sql_parser_expr), }) } - Expr::ScalarVariable(_, _) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") + Expr::ScalarVariable(_, ids) => { + if ids.is_empty() { + return internal_err!("Not a valid ScalarVariable"); + } + + Ok(if ids.len() == 1 { + ast::Expr::Identifier( + self.new_ident_without_quote_style(ids[0].to_string()), + ) + } else { + ast::Expr::CompoundIdentifier( + ids.iter() + .map(|i| self.new_ident_without_quote_style(i.to_string())) + .collect(), + ) + }) } - Expr::IsNull(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), - Expr::IsNotFalse(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), Expr::GetIndexedField(_) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } - Expr::TryCast(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::TryCast(TryCast { expr, data_type }) => { + let inner_expr = self.expr_to_sql(expr)?; + Ok(ast::Expr::TryCast { + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }) + } Expr::Wildcard { qualifier: _ } => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } Expr::GroupingSet(_) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } - Expr::Placeholder(_) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") - } - Expr::OuterReferenceColumn(_, _) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") + Expr::Placeholder(p) => { + Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string()))) } + Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), } } @@ -488,6 
+511,13 @@ impl Unparser<'_> { } } + pub(super) fn new_ident_without_quote_style(&self, str: String) -> ast::Ident { + ast::Ident { + value: str, + quote_style: None, + } + } + pub(super) fn binary_op_to_sql( &self, lhs: ast::Expr, @@ -859,12 +889,14 @@ mod tests { use std::{any::Any, sync::Arc, vec}; use arrow::datatypes::{Field, Schema}; + use arrow_schema::DataType::Int8; use datafusion_common::TableReference; use datafusion_expr::{ case, col, exists, expr::{AggregateFunction, AggregateFunctionDefinition}, - lit, not, not_exists, table_scan, wildcard, ColumnarValue, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, + lit, not, not_exists, out_ref_col, placeholder, table_scan, try_cast, when, + wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, + WindowFrame, WindowFunctionDefinition, }; use crate::unparser::dialect::CustomDialect; @@ -933,6 +965,14 @@ mod tests { .otherwise(lit(ScalarValue::Null))?, r#"CASE "a" WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END"#, ), + ( + when(col("a").is_null(), lit(true)).otherwise(lit(false))?, + r#"CASE WHEN "a" IS NULL THEN true ELSE false END"#, + ), + ( + when(col("a").is_not_null(), lit(true)).otherwise(lit(false))?, + r#"CASE WHEN "a" IS NOT NULL THEN true ELSE false END"#, + ), ( Expr::Cast(Cast { expr: Box::new(col("a")), @@ -959,6 +999,18 @@ mod tests { ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]), r#"dummy_udf("a", "b")"#, ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()) + .call(vec![col("a"), col("b")]) + .is_null(), + r#"dummy_udf("a", "b") IS NULL"#, + ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()) + .call(vec![col("a"), col("b")]) + .is_not_null(), + r#"dummy_udf("a", "b") IS NOT NULL"#, + ), ( Expr::Like(Like { negated: true, @@ -1081,6 +1133,7 @@ mod tests { r#"COUNT(*) OVER (ORDER BY "a" DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, ), (col("a").is_not_null(), r#""a" IS NOT NULL"#), + (col("a").is_null(), r#""a" IS NULL"#), ( (col("a") + col("b")).gt(lit(4)).is_true(), r#"(("a" + "b") > 4) IS TRUE"#, @@ -1093,6 +1146,10 @@ mod tests { (col("a") + col("b")).gt(lit(4)).is_false(), r#"(("a" + "b") > 4) IS FALSE"#, ), + ( + (col("a") + col("b")).gt(lit(4)).is_not_false(), + r#"(("a" + "b") > 4) IS NOT FALSE"#, + ), ( (col("a") + col("b")).gt(lit(4)).is_unknown(), r#"(("a" + "b") > 4) IS UNKNOWN"#, @@ -1115,6 +1172,30 @@ mod tests { not_exists(Arc::new(dummy_logical_plan.clone())), r#"NOT EXISTS (SELECT "t"."a" FROM "t" WHERE ("t"."a" = 1))"#, ), + ( + try_cast(col("a"), DataType::Date64), + r#"TRY_CAST("a" AS DATETIME)"#, + ), + ( + try_cast(col("a"), DataType::UInt32), + r#"TRY_CAST("a" AS INTEGER UNSIGNED)"#, + ), + ( + Expr::ScalarVariable(Int8, vec![String::from("@a")]), + r#"@a"#, + ), + ( + Expr::ScalarVariable( + Int8, + vec![String::from("@root"), String::from("foo")], + ), + r#"@root.foo"#, + ), + (col("x").eq(placeholder("$1")), r#"("x" = $1)"#), + ( + out_ref_col(DataType::Int32, "t.a").gt(lit(1)), + r#"("t"."a" > 1)"#, + ), ]; for (expr, expected) in tests { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 40d66f9b52ce..983f8a085ba9 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -198,6 +198,73 @@ statement error This feature is not implemented: LIMIT not supported in ARRAY_AG SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 +# Test distinct aggregate function with merge 
batch +query II +with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 + ---- The order is non-deterministic, verify with length +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +3 1 + +# It has only AggregateExec with FinalPartitioned mode, so `merge_batch` is used +# If the plan is changed, whether the `merge_batch` is used should be verified to ensure the test coverage +query TT +explain with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +logical_plan +01)Projection: array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1)) +02)--Aggregate: groupBy=[[a.id]], aggr=[[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))]] +03)----SubqueryAlias: a +04)------SubqueryAlias: a +05)--------Union +06)----------Projection: Int64(1) AS id, Int64(2) AS foo +07)------------EmptyRelation +08)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +09)------------EmptyRelation +10)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +11)------------EmptyRelation +12)----------Projection: Int64(1) AS id, Int64(3) AS foo +13)------------EmptyRelation +14)----------Projection: Int64(1) AS id, Int64(2) AS foo +15)------------EmptyRelation +physical_plan +01)ProjectionExec: expr=[array_length(ARRAY_AGG(DISTINCT a.foo)@1) as array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1))@2 as SUM(DISTINCT Int64(1))] +02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))] +06)----------UnionExec +07)------------ProjectionExec: expr=[1 as id, 2 as foo] +08)--------------PlaceholderRowExec +09)------------ProjectionExec: expr=[1 as id, NULL as foo] +10)--------------PlaceholderRowExec +11)------------ProjectionExec: expr=[1 as id, NULL as foo] +12)--------------PlaceholderRowExec +13)------------ProjectionExec: expr=[1 as id, 3 as foo] +14)--------------PlaceholderRowExec +15)------------ProjectionExec: expr=[1 as id, 2 as foo] +16)--------------PlaceholderRowExec + + # FIX: custom absolute values # csv_query_avg_multi_batch @@ -2480,6 +2547,417 @@ Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; statement ok drop table t; +################# +# Min_Max Begin # +################# +# min_decimal, max_decimal +statement ok +CREATE TABLE decimals (value DECIMAL(10, 2)); + +statement ok +INSERT INTO decimals VALUES (123.0001), (124.00); + +query RR +SELECT MIN(value), MAX(value) FROM decimals; +---- +123 124 + +statement ok +DROP TABLE decimals; + +statement ok +CREATE TABLE decimals_batch (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_batch VALUES (1), (2), (3), (4), (5); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_batch; +---- +1 5 + +statement ok +DROP TABLE decimals_batch; + +statement ok +CREATE TABLE decimals_empty (value DECIMAL(10, 0)); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_empty; +---- +NULL NULL + +statement ok +DROP TABLE decimals_empty; + +# min_decimal_all_nulls, max_decimal_all_nulls +statement 
ok +CREATE TABLE decimals_all_nulls (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_all_nulls VALUES (NULL), (NULL), (NULL), (NULL), (NULL), (NULL); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_all_nulls; +---- +NULL NULL + +statement ok +DROP TABLE decimals_all_nulls; + +# min_decimal_with_nulls, max_decimal_with_nulls +statement ok +CREATE TABLE decimals_with_nulls (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_with_nulls VALUES (1), (NULL), (3), (4), (5); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_with_nulls; +---- +1 5 + +statement ok +DROP TABLE decimals_with_nulls; + +statement ok +CREATE TABLE decimals_error (value DECIMAL(10, 2)); + +statement ok +INSERT INTO decimals_error VALUES (123.00), (arrow_cast(124.001, 'Decimal128(10, 3)')); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_error; +---- +123 124 + +statement ok +DROP TABLE decimals_error; + +statement ok +CREATE TABLE decimals_agg (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_agg VALUES (1), (2), (3), (4), (5); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_agg; +---- +1 5 + +statement ok +DROP TABLE decimals_agg; + +# min_i32, max_i32 +statement ok +CREATE TABLE integers (value INT); + +statement ok +INSERT INTO integers VALUES (1), (2), (3), (4), (5); + +query II +SELECT MIN(value), MAX(value) FROM integers +---- +1 5 + +statement ok +DROP TABLE integers; + +# min_utf8, max_utf8 +statement ok +CREATE TABLE strings (value TEXT); + +statement ok +INSERT INTO strings VALUES ('d'), ('a'), ('c'), ('b'); + +query TT +SELECT MIN(value), MAX(value) FROM strings +---- +a d + +statement ok +DROP TABLE strings; + +# min_i32_with_nulls, max_i32_with_nulls +statement ok +CREATE TABLE integers_with_nulls (value INT); + +statement ok +INSERT INTO integers_with_nulls VALUES (1), (NULL), (3), (4), (5); + +query II +SELECT MIN(value), MAX(value) FROM integers_with_nulls +---- +1 5 + +statement ok +DROP TABLE integers_with_nulls; + +# min_i32_all_nulls, max_i32_all_nulls +statement ok +CREATE TABLE integers_all_nulls (value INT); + +query II +SELECT MIN(value), MAX(value) FROM integers_all_nulls +---- +NULL NULL + +statement ok +DROP TABLE integers_all_nulls; + +# min_u32, max_u32 +statement ok +CREATE TABLE uintegers (value INT UNSIGNED); + +statement ok +INSERT INTO uintegers VALUES (1), (2), (3), (4), (5); + +query II +SELECT MIN(value), MAX(value) FROM uintegers +---- +1 5 + +statement ok +DROP TABLE uintegers; + +# min_f32, max_f32 +statement ok +CREATE TABLE floats (value FLOAT); + +statement ok +INSERT INTO floats VALUES (1.0), (2.0), (3.0), (4.0), (5.0); + +query RR +SELECT MIN(value), MAX(value) FROM floats +---- +1 5 + +statement ok +DROP TABLE floats; + +# min_f64, max_f64 +statement ok +CREATE TABLE doubles (value DOUBLE); + +statement ok +INSERT INTO doubles VALUES (1.0), (2.0), (3.0), (4.0), (5.0); + +query RR +SELECT MIN(value), MAX(value) FROM doubles +---- +1 5 + +statement ok +DROP TABLE doubles; + +# min_date, max_date +statement ok +CREATE TABLE dates (value DATE); + +statement ok +INSERT INTO dates VALUES ('1970-01-02'), ('1970-01-03'), ('1970-01-04'), ('1970-01-05'), ('1970-01-06'); + +query DD +SELECT MIN(value), MAX(value) FROM dates +---- +1970-01-02 1970-01-06 + +statement ok +DROP TABLE dates; + +# min_seconds, max_seconds +statement ok +CREATE TABLE times (value TIME); + +statement ok +INSERT INTO times VALUES ('00:00:01'), ('00:00:02'), ('00:00:03'), ('00:00:04'), ('00:00:05'); + +query DD +SELECT MIN(value), MAX(value) FROM 
times +---- +00:00:01 00:00:05 + +statement ok +DROP TABLE times; + +# min_milliseconds, max_milliseconds +statement ok +CREATE TABLE time32millisecond (value TIME); + +statement ok +INSERT INTO time32millisecond VALUES ('00:00:00.001'), ('00:00:00.002'), ('00:00:00.003'), ('00:00:00.004'), ('00:00:00.005'); + +query DD +SELECT MIN(value), MAX(value) FROM time32millisecond +---- +00:00:00.001 00:00:00.005 + +statement ok +DROP TABLE time32millisecond; + +# min_microseconds, max_microseconds +statement ok +CREATE TABLE time64microsecond (value TIME); + +statement ok +INSERT INTO time64microsecond VALUES ('00:00:00.000001'), ('00:00:00.000002'), ('00:00:00.000003'), ('00:00:00.000004'), ('00:00:00.000005'); + +query DD +SELECT MIN(value), MAX(value) FROM time64microsecond +---- +00:00:00.000001 00:00:00.000005 + +statement ok +DROP TABLE time64microsecond; + +# min_nanoseconds, max_nanoseconds +statement ok +CREATE TABLE time64nanosecond (value TIME); + +statement ok +INSERT INTO time64nanosecond VALUES ('00:00:00.000000001'), ('00:00:00.000000002'), ('00:00:00.000000003'), ('00:00:00.000000004'), ('00:00:00.000000005'); + +query DD +SELECT MIN(value), MAX(value) FROM time64nanosecond +---- +00:00:00.000000001 00:00:00.000000005 + +statement ok +DROP TABLE time64nanosecond; + +# min_timestamp, max_timestamp +statement ok +CREATE TABLE timestampmicrosecond (value TIMESTAMP); + +statement ok +INSERT INTO timestampmicrosecond VALUES ('1970-01-01 00:00:00.000001'), ('1970-01-01 00:00:00.000002'), ('1970-01-01 00:00:00.000003'), ('1970-01-01 00:00:00.000004'), ('1970-01-01 00:00:00.000005'); + +query PP +SELECT MIN(value), MAX(value) FROM timestampmicrosecond +---- +1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000005 + +statement ok +DROP TABLE timestampmicrosecond; + +# max_bool +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (false), (false); + +query B +SELECT MAX(value) FROM max_bool +---- +false + +statement ok +DROP TABLE max_bool; + +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (true), (true); + +query B +SELECT MAX(value) FROM max_bool +---- +true + +statement ok +DROP TABLE max_bool; + +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (false), (true), (false); + +query B +SELECT MAX(value) FROM max_bool +---- +true + +statement ok +DROP TABLE max_bool; + +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (true), (false), (true); + +query B +SELECT MAX(value) FROM max_bool +---- +true + +statement ok +DROP TABLE max_bool; + +# min_bool +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (false), (false); + +query B +SELECT MIN(value) FROM min_bool +---- +false + +statement ok +DROP TABLE min_bool; + +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (true), (true); + +query B +SELECT MIN(value) FROM min_bool +---- +true + +statement ok +DROP TABLE min_bool; + +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (false), (true), (false); + +query B +SELECT MIN(value) FROM min_bool +---- +false + +statement ok +DROP TABLE min_bool; + +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (true), (false), (true); + +query B +SELECT MIN(value) FROM min_bool +---- +false + +statement ok +DROP TABLE min_bool; + 
+################# +# Min_Max End # +################# + statement ok create table bool_aggregate_functions ( c1 boolean not null, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index eeb5dc01b6e7..9b8b50201243 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1963,6 +1963,18 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), co [1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] [5] [, 54, 55, 56, 57, 58, 59, 60] [55] +# Test issue: https://github.com/apache/datafusion/issues/10425 +# `from` may be larger than `to` and `stride` is positive +query ???? +select array_slice(a, -1, 2, 1), array_slice(a, -1, 2), + array_slice(a, 3, 2, 1), array_slice(a, 3, 2) + from (values ([1.0, 2.0, 3.0, 3.0]), ([4.0, 5.0, 3.0]), ([6.0])) t(a); +---- +[] [] [] [] +[] [] [] [] +[6.0] [6.0] [] [] + + # make_array with nulls query ??????? select make_array(make_array('a','b'), null), diff --git a/datafusion/sqllogictest/test_files/dates.slt b/datafusion/sqllogictest/test_files/dates.slt index 32c0bd14e7cc..e21637bd8913 100644 --- a/datafusion/sqllogictest/test_files/dates.slt +++ b/datafusion/sqllogictest/test_files/dates.slt @@ -224,5 +224,11 @@ SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', query error function unsupported data type at index 1: SELECT to_date(t.ts, make_array('%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+')) from ts_utf8_data as t +# verify to_date with format +query D +select to_date('2022-01-23', '%Y-%m-%d'); +---- +2022-01-23 + statement ok drop table ts_utf8_data diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 3a4ac747ebd6..92c537f975ad 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -252,9 +252,9 @@ physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] @@ -311,9 +311,9 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, 
[(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] @@ -348,9 +348,9 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan 01)GlobalLimitExec: skip=0, fetch=10 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 0f869fc0b419..fb07d5ebe895 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -955,3 +955,154 @@ drop table foo; statement ok drop table ambiguity_test; + +# Casting from numeric to string types breaks the ordering +statement ok +CREATE EXTERNAL TABLE ordered_table ( + a0 INT, + a INT, + b INT, + c INT, + d INT +) +STORED AS CSV +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); + +query T +SELECT CAST(c as VARCHAR) as c_str +FROM ordered_table +ORDER BY c_str +limit 5; +---- +0 +1 +10 +11 +12 + +query TT +EXPLAIN SELECT CAST(c as VARCHAR) as c_str +FROM ordered_table +ORDER BY c_str +limit 5; +---- +logical_plan +01)Limit: skip=0, fetch=5 +02)--Sort: c_str ASC NULLS LAST, fetch=5 +03)----Projection: CAST(ordered_table.c AS Utf8) AS c_str +04)------TableScan: ordered_table projection=[c] +physical_plan +01)GlobalLimitExec: skip=0, fetch=5 +02)--SortPreservingMergeExec: [c_str@0 ASC NULLS LAST], fetch=5 +03)----SortExec: TopK(fetch=5), expr=[c_str@0 ASC NULLS LAST], preserve_partitioning=[true] +04)------ProjectionExec: expr=[CAST(c@0 AS Utf8) as c_str] +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + + +# Casting from numeric to numeric types preserves the ordering +query I +SELECT CAST(c as BIGINT) 
as c_bigint +FROM ordered_table +ORDER BY c_bigint +limit 5; +---- +0 +1 +2 +3 +4 + +query TT +EXPLAIN SELECT CAST(c as BIGINT) as c_bigint +FROM ordered_table +ORDER BY c_bigint +limit 5; +---- +logical_plan +01)Limit: skip=0, fetch=5 +02)--Sort: c_bigint ASC NULLS LAST, fetch=5 +03)----Projection: CAST(ordered_table.c AS Int64) AS c_bigint +04)------TableScan: ordered_table projection=[c] +physical_plan +01)GlobalLimitExec: skip=0, fetch=5 +02)--SortPreservingMergeExec: [c_bigint@0 ASC NULLS LAST], fetch=5 +03)----ProjectionExec: expr=[CAST(c@0 AS Int64) as c_bigint] +04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +statement ok +drop table ordered_table; + + +# ABS(x) breaks the ordering if x's range contains both negative and positive values. +# Since x is defined as INT, its range is assumed to be from NEG_INF to INF. +statement ok +CREATE EXTERNAL TABLE ordered_table ( + a0 INT, + a INT, + b INT, + c INT, + d INT +) +STORED AS CSV +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); + +query TT +EXPLAIN SELECT ABS(c) as abs_c +FROM ordered_table +ORDER BY abs_c +limit 5; +---- +logical_plan +01)Limit: skip=0, fetch=5 +02)--Sort: abs_c ASC NULLS LAST, fetch=5 +03)----Projection: abs(ordered_table.c) AS abs_c +04)------TableScan: ordered_table projection=[c] +physical_plan +01)GlobalLimitExec: skip=0, fetch=5 +02)--SortPreservingMergeExec: [abs_c@0 ASC NULLS LAST], fetch=5 +03)----SortExec: TopK(fetch=5), expr=[abs_c@0 ASC NULLS LAST], preserve_partitioning=[true] +04)------ProjectionExec: expr=[abs(c@0) as abs_c] +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +statement ok +drop table ordered_table; + +# ABS(x) preserves the ordering if x's range falls into positive values. +# Since x is defined as INT UNSIGNED, its range is assumed to be from 0 to INF. 
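+# Hence, unlike the signed INT case above, the plan below does not need an extra SortExec.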
+statement ok +CREATE EXTERNAL TABLE ordered_table ( + a0 INT, + a INT, + b INT, + c INT UNSIGNED, + d INT +) +STORED AS CSV +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); + +query TT +EXPLAIN SELECT ABS(c) as abs_c +FROM ordered_table +ORDER BY abs_c +limit 5; +---- +logical_plan +01)Limit: skip=0, fetch=5 +02)--Sort: abs_c ASC NULLS LAST, fetch=5 +03)----Projection: abs(ordered_table.c) AS abs_c +04)------TableScan: ordered_table projection=[c] +physical_plan +01)GlobalLimitExec: skip=0, fetch=5 +02)--SortPreservingMergeExec: [abs_c@0 ASC NULLS LAST], fetch=5 +03)----ProjectionExec: expr=[abs(c@0) as abs_c] +04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 09a2aa3e7436..7b7e355fa2b5 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -263,5 +263,22 @@ DROP TABLE t1; statement ok DROP TABLE t2; +# Set batch size to 1 for sort merge join to test scenario when data spread across multiple batches +statement ok +set datafusion.execution.batch_size = 1; + +query II +SELECT * FROM ( + WITH + t1 AS ( + SELECT 12 a, 12 b + ), + t2 AS ( + SELECT 12 a, 12 b + ) + SELECT t1.* FROM t1 JOIN t2 on t1.a = t2.b WHERE t1.a > t2.b +) ORDER BY 1, 2; +---- + statement ok set datafusion.optimizer.prefer_hash_join = true; diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 13fb8fba0d31..5f75bca4f0fa 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2795,3 +2795,9 @@ SELECT '2000-12-01 04:04:12' AT TIME ZONE 'America/New York'; # abbreviated timezone is not supported statement error SELECT '2023-03-12 02:00:00' AT TIME ZONE 'EDT'; + +# Test current_time without parentheses +query B +select current_time = current_time; +---- +true diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index dce8ce10b587..e4be6e68ff16 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -39,7 +39,7 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.32.0" +substrait = "0.33.3" [dev-dependencies] tokio = { workspace = true } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index db5d341bc225..6f0738c38df5 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -722,7 +722,10 @@ pub fn to_substrait_agg_measure( arguments, sorts, output_type: None, - invocation: AggregationInvocation::All as i32, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, phase: AggregationPhase::Unspecified as i32, args: vec![], options: vec![], diff --git a/docs/source/contributor-guide/getting_started.md b/docs/source/contributor-guide/getting_started.md new file mode 100644 index 000000000000..64d5a0d43d5d --- /dev/null +++ b/docs/source/contributor-guide/getting_started.md @@ -0,0 +1,87 @@ 
+
+
+# Getting Started
+
+This section describes how you can get started developing DataFusion.
+
+## Windows setup
+
+```shell
+wget https://az792536.vo.msecnd.net/vms/VMBuild_20190311/VirtualBox/MSEdge/MSEdge.Win10.VirtualBox.zip
+choco install -y git rustup.install visualcpp-build-tools
+git-bash.exe
+cargo build
+```
+
+## Protoc Installation
+
+Compiling DataFusion from source requires an installed version of the protobuf compiler, `protoc`.
+
+On most platforms it can be installed from your system's package manager:
+
+```
+# Ubuntu
+$ sudo apt install -y protobuf-compiler
+
+# Fedora
+$ dnf install -y protobuf-devel
+
+# Arch Linux
+$ pacman -S protobuf
+
+# macOS
+$ brew install protobuf
+```
+
+You will want to verify that the installed version is `3.12` or greater, which introduced support for explicit [field presence](https://github.com/protocolbuffers/protobuf/blob/v3.12.0/docs/field_presence.md). Older versions may fail to compile.
+
+```shell
+$ protoc --version
+libprotoc 3.12.4
+```
+
+Alternatively, a binary release can be downloaded from the [Release Page](https://github.com/protocolbuffers/protobuf/releases) or [built from source](https://github.com/protocolbuffers/protobuf/blob/main/src/README.md).
+
+## Bootstrap environment
+
+DataFusion is written in Rust and uses a standard Rust toolkit:
+
+- `cargo build`
+- `cargo fmt` to format the code
+- `cargo test` to test
+- etc.
+
+Note that running `cargo test` requires significant memory resources, because cargo runs many tests in parallel by default. If you run into issues with slow tests or system lockups, you can significantly reduce the memory required by instead running `cargo test -- --test-threads=1`. For more information see [this issue](https://github.com/apache/datafusion/issues/5347).
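+
+For example, a common local development loop looks roughly like the following (the specific crate name and flags below are only an illustrative sketch, not a required workflow):
+
+```shell
+# format the code and run lints before sending a PR
+cargo fmt
+cargo clippy
+
+# while iterating, test a single crate to keep memory usage down
+cargo test -p datafusion-common
+
+# run the full suite serially if parallel test runs exhaust memory
+cargo test -- --test-threads=1
+```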
+
+Testing setup:
+
+- `rustup update stable` (DataFusion uses the latest stable release of Rust)
+- `git submodule init`
+- `git submodule update`
+
+Formatting instructions:
+
+- [ci/scripts/rust_fmt.sh](../../../ci/scripts/rust_fmt.sh)
+- [ci/scripts/rust_clippy.sh](../../../ci/scripts/rust_clippy.sh)
+- [ci/scripts/rust_toml_fmt.sh](../../../ci/scripts/rust_toml_fmt.sh)
+
+or run them all at once:
+
+- [dev/rust_lint.sh](../../../dev/rust_lint.sh)
diff --git a/docs/source/contributor-guide/howtos.md b/docs/source/contributor-guide/howtos.md
new file mode 100644
index 000000000000..254b1de6521e
--- /dev/null
+++ b/docs/source/contributor-guide/howtos.md
@@ -0,0 +1,129 @@
+
+
+# HOWTOs
+
+## How to add a new scalar function
+
+Below is a checklist of what you need to do to add a new scalar function to DataFusion:
+
+- Add the actual implementation of the function to a new module file within:
+  - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions-array) for array functions
+  - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/crypto) for crypto functions
+  - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/datetime) for datetime functions
+  - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/encoding) for encoding functions
+  - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/math) for math functions
+  - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/regex) for regex functions
+  - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/string) for string functions
+  - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/unicode) for unicode functions
+  - create a new module [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/) for other functions.
+- New function modules (for example, a `vector` module) should use a [Rust feature](https://doc.rust-lang.org/cargo/reference/features.html) (for example `vector_expressions`) to allow DataFusion
+  users to enable or disable the new module as desired.
+- Implement the function by implementing the `ScalarUDFImpl` trait for the function struct.
+  - See [advanced_udf.rs] for an example implementation
+  - Add tests for the new function
+- To connect the implementation of the function, add the following to the module's `mod.rs` file:
+  - a `mod xyz;` where xyz is the new module file
+  - a call to `make_udf_function!(..);`
+  - an item in `export_functions!(..);`
+- In [sqllogictest/test_files], add new `sqllogictest` integration tests where the function is called through SQL against well-known data and returns the expected result.
+  - Documentation for `sqllogictest` [here](https://github.com/apache/datafusion/blob/main/datafusion/sqllogictest/README.md)
+- Add SQL reference documentation [here](https://github.com/apache/datafusion/blob/main/docs/source/user-guide/sql/scalar_functions.md)
+
+[advanced_udf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+[sqllogictest/test_files]: https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest/test_files
+
+## How to add a new aggregate function
+
+Below is a checklist of what you need to do to add a new aggregate function to DataFusion:
+
+- Add the actual implementation of an `Accumulator` and `AggregateExpr`.
+- In [datafusion/expr/src](../../../datafusion/expr/src/aggregate_function.rs), add:
+  - a new variant to `AggregateFunction`
+  - a new entry to `FromStr` with the name of the function as called by SQL
+  - a new line in `return_type` with the expected return type of the function, given an incoming type
+  - a new line in `signature` with the signature of the function (number and types of its arguments)
+  - a new line in `create_aggregate_expr` mapping the built-in to the implementation
+  - tests for the function.
+- In [sqllogictest/test_files], add new `sqllogictest` integration tests where the function is called through SQL against well-known data and returns the expected result.
+  - Documentation for `sqllogictest` [here](https://github.com/apache/datafusion/blob/main/datafusion/sqllogictest/README.md)
+- Add SQL reference documentation [here](https://github.com/apache/datafusion/blob/main/docs/source/user-guide/sql/aggregate_functions.md)
+
+## How to display plans graphically
+
+The query plans represented by `LogicalPlan` nodes can be graphically
+rendered using [Graphviz](https://www.graphviz.org/).
+
+To do so, save the output of the `display_graphviz` function to a file:
+
+```rust
+// Create plan somehow...
+let mut output = File::create("/tmp/plan.dot")?;
+write!(output, "{}", plan.display_graphviz())?;
+```
+
+Then, use the `dot` command line tool to render it into a file that
+can be displayed. For example, the following command creates a
+`/tmp/plan.pdf` file:
+
+```bash
+dot -Tpdf < /tmp/plan.dot > /tmp/plan.pdf
+```
+
+## How to format `.md` documents
+
+We use `prettier` to format `.md` files.
+
+You can either use `npm i -g prettier` to install it globally or use `npx` to run it as a standalone binary. Using `npx` requires a working Node.js environment. Upgrading to the latest prettier is recommended (re-running `npm i -g prettier` will install the latest version).
+
+```bash
+$ prettier --version
+2.3.0
+```
+
+After you've confirmed your prettier version, you can format all the `.md` files:
+
+```bash
+prettier -w {datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md
+```
+
+## How to format `.toml` files
+
+We use `taplo` to format `.toml` files.
+
+For Rust developers, you can install it via:
+
+```sh
+cargo install taplo-cli --locked
+```
+
+> Refer to the [Installation section][doc] for other ways to install it.
+> +> [doc]: https://taplo.tamasfe.dev/cli/installation/binary.html + +```bash +$ taplo --version +taplo 0.9.0 +``` + +After you've confirmed your `taplo` version, you can format all the `.toml` files: + +```bash +taplo fmt +``` diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index 5705737206da..9aaa8b045388 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -113,232 +113,6 @@ The good thing about open code and open development is that any issues in one ch Pull requests will be marked with a `stale` label after 60 days of inactivity and then closed 7 days after that. Commenting on the PR will remove the `stale` label. -## Getting Started - -This section describes how you can get started at developing DataFusion. - -### Windows setup - -```shell -wget https://az792536.vo.msecnd.net/vms/VMBuild_20190311/VirtualBox/MSEdge/MSEdge.Win10.VirtualBox.zip -choco install -y git rustup.install visualcpp-build-tools -git-bash.exe -cargo build -``` - -### Protoc Installation - -Compiling DataFusion from sources requires an installed version of the protobuf compiler, `protoc`. - -On most platforms this can be installed from your system's package manager - -``` -# Ubuntu -$ sudo apt install -y protobuf-compiler - -# Fedora -$ dnf install -y protobuf-devel - -# Arch Linux -$ pacman -S protobuf - -# macOS -$ brew install protobuf -``` - -You will want to verify the version installed is `3.12` or greater, which introduced support for explicit [field presence](https://github.com/protocolbuffers/protobuf/blob/v3.12.0/docs/field_presence.md). Older versions may fail to compile. - -```shell -$ protoc --version -libprotoc 3.12.4 -``` - -Alternatively a binary release can be downloaded from the [Release Page](https://github.com/protocolbuffers/protobuf/releases) or [built from source](https://github.com/protocolbuffers/protobuf/blob/main/src/README.md). - -### Bootstrap environment - -DataFusion is written in Rust and it uses a standard rust toolkit: - -- `cargo build` -- `cargo fmt` to format the code -- `cargo test` to test -- etc. - -Note that running `cargo test` requires significant memory resources, due to cargo running many tests in parallel by default. If you run into issues with slow tests or system lock ups, you can significantly reduce the memory required by instead running `cargo test -- --test-threads=1`. For more information see [this issue](https://github.com/apache/datafusion/issues/5347). - -Testing setup: - -- `rustup update stable` DataFusion uses the latest stable release of rust -- `git submodule init` -- `git submodule update` - -Formatting instructions: - -- [ci/scripts/rust_fmt.sh](../../../ci/scripts/rust_fmt.sh) -- [ci/scripts/rust_clippy.sh](../../../ci/scripts/rust_clippy.sh) -- [ci/scripts/rust_toml_fmt.sh](../../../ci/scripts/rust_toml_fmt.sh) - -or run them all at once: - -- [dev/rust_lint.sh](../../../dev/rust_lint.sh) - -## Testing - -Tests are critical to ensure that DataFusion is working properly and -is not accidentally broken during refactorings. All new features -should have test coverage. - -DataFusion has several levels of tests in its [Test -Pyramid](https://martinfowler.com/articles/practical-test-pyramid.html) -and tries to follow the Rust standard [Testing Organization](https://doc.rust-lang.org/book/ch11-03-test-organization.html) in the The Book. 
- -### Unit tests - -Tests for code in an individual module are defined in the same source file with a `test` module, following Rust convention. - -### sqllogictests Tests - -DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest) which are run like any other Rust test using `cargo test --test sqllogictests`. - -`sqllogictests` tests may be less convenient for new contributors who are familiar with writing `.rs` tests as they require learning another tool. However, `sqllogictest` based tests are much easier to develop and maintain as they 1) do not require a slow recompile/link cycle and 2) can be automatically updated via `cargo test --test sqllogictests -- --complete`. - -Like similar systems such as [DuckDB](https://duckdb.org/dev/testing), DataFusion has chosen to trade off a slightly higher barrier to contribution for longer term maintainability. - -### Rust Integration Tests - -There are several tests of the public interface of the DataFusion library in the [tests](https://github.com/apache/datafusion/tree/main/datafusion/core/tests) directory. - -You can run these tests individually using `cargo` as normal command such as - -```shell -cargo test -p datafusion --test parquet_exec -``` - -## Benchmarks - -### Criterion Benchmarks - -[Criterion](https://docs.rs/criterion/latest/criterion/index.html) is a statistics-driven micro-benchmarking framework used by DataFusion for evaluating the performance of specific code-paths. In particular, the criterion benchmarks help to both guide optimisation efforts, and prevent performance regressions within DataFusion. - -Criterion integrates with Cargo's built-in [benchmark support](https://doc.rust-lang.org/cargo/commands/cargo-bench.html) and a given benchmark can be run with - -``` -cargo bench --bench BENCHMARK_NAME -``` - -A full list of benchmarks can be found [here](https://github.com/apache/datafusion/tree/main/datafusion/core/benches). - -_[cargo-criterion](https://github.com/bheisler/cargo-criterion) may also be used for more advanced reporting._ - -### Parquet SQL Benchmarks - -The parquet SQL benchmarks can be run with - -``` - cargo bench --bench parquet_query_sql -``` - -These randomly generate a parquet file, and then benchmark queries sourced from [parquet_query_sql.sql](../../../datafusion/core/benches/parquet_query_sql.sql) against it. This can therefore be a quick way to add coverage of particular query and/or data paths. - -If the environment variable `PARQUET_FILE` is set, the benchmark will run queries against this file instead of a randomly generated one. This can be useful for performing multiple runs, potentially with different code, against the same source data, or for testing against a custom dataset. - -The benchmark will automatically remove any generated parquet file on exit, however, if interrupted (e.g. by CTRL+C) it will not. This can be useful for analysing the particular file after the fact, or preserving it to use with `PARQUET_FILE` in subsequent runs. - -### Comparing Baselines - -By default, Criterion.rs will compare the measurements against the previous run (if any). Sometimes it's useful to keep a set of measurements around for several runs. For example, you might want to make multiple changes to the code while comparing against the master branch. For this situation, Criterion.rs supports custom baselines. 
- -``` - git checkout main - cargo bench --bench sql_planner -- --save-baseline main - git checkout YOUR_BRANCH - cargo bench --bench sql_planner -- --baseline main -``` - -Note: For MacOS it may be required to run `cargo bench` with `sudo` - -``` -sudo cargo bench ... -``` - -More information on [Baselines](https://bheisler.github.io/criterion.rs/book/user_guide/command_line_options.html#baselines) - -### Upstream Benchmark Suites - -Instructions and tooling for running upstream benchmark suites against DataFusion can be found in [benchmarks](https://github.com/apache/datafusion/tree/main/benchmarks). - -These are valuable for comparative evaluation against alternative Arrow implementations and query engines. - -## HOWTOs - -### How to add a new scalar function - -Below is a checklist of what you need to do to add a new scalar function to DataFusion: - -- Add the actual implementation of the function to a new module file within: - - [here](../../../datafusion/functions-array/src) for array functions - - [here](../../../datafusion/functions/src/crypto) for crypto functions - - [here](../../../datafusion/functions/src/datetime) for datetime functions - - [here](../../../datafusion/functions/src/encoding) for encoding functions - - [here](../../../datafusion/functions/src/math) for math functions - - [here](../../../datafusion/functions/src/regex) for regex functions - - [here](../../../datafusion/functions/src/string) for string functions - - [here](../../../datafusion/functions/src/unicode) for unicode functions - - create a new module [here](../../../datafusion/functions/src) for other functions. -- New function modules - for example a `vector` module, should use a [rust feature](https://doc.rust-lang.org/cargo/reference/features.html) (for example `vector_expressions`) to allow DataFusion - users to enable or disable the new module as desired. -- The implementation of the function is done via implementing `ScalarUDFImpl` trait for the function struct. - - See the [advanced_udf.rs](../../../datafusion-examples/examples/advanced_udf.rs) example for an example implementation - - Add tests for the new function -- To connect the implementation of the function add to the mod.rs file: - - a `mod xyz;` where xyz is the new module file - - a call to `make_udf_function!(..);` - - an item in `export_functions!(..);` -- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. 
- - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) -- Add SQL reference documentation [here](../../../docs/source/user-guide/sql/scalar_functions.md) - -### How to add a new aggregate function - -Below is a checklist of what you need to do to add a new aggregate function to DataFusion: - -- Add the actual implementation of an `Accumulator` and `AggregateExpr`: - - [here](../../../datafusion/physical-expr/src/string_expressions.rs) for string functions - - [here](../../../datafusion/physical-expr/src/math_expressions.rs) for math functions - - [here](../../../datafusion/functions/src/datetime/mod.rs) for datetime functions - - create a new module [here](../../../datafusion/physical-expr/src) for other functions -- In [datafusion/expr/src](../../../datafusion/expr/src/aggregate_function.rs), add: - - a new variant to `AggregateFunction` - - a new entry to `FromStr` with the name of the function as called by SQL - - a new line in `return_type` with the expected return type of the function, given an incoming type - - a new line in `signature` with the signature of the function (number and types of its arguments) - - a new line in `create_aggregate_expr` mapping the built-in to the implementation - - tests to the function. -- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. - - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) -- Add SQL reference documentation [here](../../../docs/source/user-guide/sql/aggregate_functions.md) - -### How to display plans graphically - -The query plans represented by `LogicalPlan` nodes can be graphically -rendered using [Graphviz](https://www.graphviz.org/). - -To do so, save the output of the `display_graphviz` function to a file.: - -```rust -// Create plan somehow... -let mut output = File::create("/tmp/plan.dot")?; -write!(output, "{}", plan.display_graphviz()); -``` - -Then, use the `dot` command line tool to render it into a file that -can be displayed. For example, the following command creates a -`/tmp/plan.pdf` file: - -```bash -dot -Tpdf < /tmp/plan.dot > /tmp/plan.pdf -``` - ## Specifications We formalize some DataFusion semantics and behaviors through specification @@ -354,45 +128,3 @@ Here is the list current active specifications: - [Invariants](https://datafusion.apache.org/contributor-guide/specification/invariants.html) All specifications are stored in the `docs/source/specification` folder. - -## How to format `.md` document - -We are using `prettier` to format `.md` files. - -You can either use `npm i -g prettier` to install it globally or use `npx` to run it as a standalone binary. Using `npx` required a working node environment. Upgrading to the latest prettier is recommended (by adding `--upgrade` to the `npm` command). - -```bash -$ prettier --version -2.3.0 -``` - -After you've confirmed your prettier version, you can format all the `.md` files: - -```bash -prettier -w {datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md -``` - -## How to format `.toml` files - -We use `taplo` to format `.toml` files. - -For Rust developers, you can install it via: - -```sh -cargo install taplo-cli --locked -``` - -> Refer to the [Installation section][doc] on other ways to install it. 
->
-> [doc]: https://taplo.tamasfe.dev/cli/installation/binary.html
-
-```bash
-$ taplo --version
-taplo 0.9.0
-```
-
-After you've confirmed your `taplo` version, you can format all the `.toml` files:
-
-```bash
-taplo fmt
-```
diff --git a/docs/source/contributor-guide/testing.md b/docs/source/contributor-guide/testing.md
new file mode 100644
index 000000000000..11f53bcb2a2d
--- /dev/null
+++ b/docs/source/contributor-guide/testing.md
@@ -0,0 +1,105 @@
+
+
+# Testing
+
+Tests are critical to ensure that DataFusion is working properly and
+is not accidentally broken during refactorings. All new features
+should have test coverage.
+
+DataFusion has several levels of tests in its [Test
+Pyramid](https://martinfowler.com/articles/practical-test-pyramid.html)
+and tries to follow the Rust standard [Testing Organization](https://doc.rust-lang.org/book/ch11-03-test-organization.html) in The Book.
+
+## Unit tests
+
+Tests for code in an individual module are defined in the same source file with a `test` module, following Rust convention.
+
+## sqllogictests Tests
+
+DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest); these tests are run like any other Rust test using `cargo test --test sqllogictests`.
+
+`sqllogictests` tests may be less convenient for new contributors who are familiar with writing `.rs` tests, as they require learning another tool. However, `sqllogictest` based tests are much easier to develop and maintain because they 1) do not require a slow recompile/link cycle and 2) can be automatically updated via `cargo test --test sqllogictests -- --complete`.
+
+Like similar systems such as [DuckDB](https://duckdb.org/dev/testing), DataFusion has chosen to trade off a slightly higher barrier to contribution for longer term maintainability.
+
+### Rust Integration Tests
+
+There are several tests of the public interface of the DataFusion library in the [tests](https://github.com/apache/datafusion/tree/main/datafusion/core/tests) directory.
+
+You can run these tests individually with a normal `cargo` command, such as:
+
+```shell
+cargo test -p datafusion --test parquet_exec
+```
+
+## Benchmarks
+
+### Criterion Benchmarks
+
+[Criterion](https://docs.rs/criterion/latest/criterion/index.html) is a statistics-driven micro-benchmarking framework used by DataFusion for evaluating the performance of specific code-paths. In particular, the criterion benchmarks help to both guide optimisation efforts and prevent performance regressions within DataFusion.
+
+Criterion integrates with Cargo's built-in [benchmark support](https://doc.rust-lang.org/cargo/commands/cargo-bench.html), and a given benchmark can be run with:
+
+```
+cargo bench --bench BENCHMARK_NAME
+```
+
+A full list of benchmarks can be found [here](https://github.com/apache/datafusion/tree/main/datafusion/core/benches).
+
+_[cargo-criterion](https://github.com/bheisler/cargo-criterion) may also be used for more advanced reporting._
+
+### Parquet SQL Benchmarks
+
+The parquet SQL benchmarks can be run with:
+
+```
+cargo bench --bench parquet_query_sql
+```
+
+These benchmarks randomly generate a parquet file, and then benchmark queries sourced from [parquet_query_sql.sql](../../../datafusion/core/benches/parquet_query_sql.sql) against it. This can therefore be a quick way to add coverage of particular query and/or data paths.
+
+If the environment variable `PARQUET_FILE` is set, the benchmark will run queries against this file instead of a randomly generated one.
This can be useful for performing multiple runs, potentially with different code, against the same source data, or for testing against a custom dataset.
+
+The benchmark will automatically remove any generated parquet file on exit; however, if interrupted (e.g. by CTRL+C) it will not. This can be useful for analysing the particular file after the fact, or for preserving it to use with `PARQUET_FILE` in subsequent runs.
+
+### Comparing Baselines
+
+By default, Criterion.rs will compare the measurements against the previous run (if any). Sometimes it's useful to keep a set of measurements around for several runs. For example, you might want to make multiple changes to the code while comparing against the master branch. For this situation, Criterion.rs supports custom baselines.
+
+```
+ git checkout main
+ cargo bench --bench sql_planner -- --save-baseline main
+ git checkout YOUR_BRANCH
+ cargo bench --bench sql_planner -- --baseline main
+```
+
+Note: on macOS it may be necessary to run `cargo bench` with `sudo`:
+
+```
+sudo cargo bench ...
+```
+
+More information on baselines is available in the [Criterion documentation](https://bheisler.github.io/criterion.rs/book/user_guide/command_line_options.html#baselines).
+
+### Upstream Benchmark Suites
+
+Instructions and tooling for running upstream benchmark suites against DataFusion can be found in [benchmarks](https://github.com/apache/datafusion/tree/main/benchmarks).
+
+These are valuable for comparative evaluation against alternative Arrow implementations and query engines.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5d6dcd3f87a2..77412e716271 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -111,13 +111,16 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for
 
    contributor-guide/index
    contributor-guide/communication
+   contributor-guide/getting_started
    contributor-guide/architecture
+   contributor-guide/testing
+   contributor-guide/howtos
    contributor-guide/roadmap
    contributor-guide/quarterly_roadmap
    contributor-guide/governance
    contributor-guide/specification/index
 
-.. _toc.contributor-guide:
+.. _toc.subprojects:
 
 .. toctree::
    :maxdepth: 1