From 256fa220a98da149dee7c69268b5cf680d83be95 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 15 Jan 2024 12:54:29 -0800 Subject: [PATCH 01/10] Make make_scalar_function private --- .../user_defined_scalar_functions.rs | 53 ++++++++++++------- .../simplify_expressions/expr_simplifier.rs | 22 ++++---- datafusion/physical-expr/src/functions.rs | 6 ++- datafusion/proto/src/bytes/mod.rs | 3 +- .../tests/cases/roundtrip_logical_plan.rs | 16 +++--- .../tests/cases/roundtrip_physical_plan.rs | 14 ++--- datafusion/proto/tests/cases/serialize.rs | 12 +++-- datafusion/sqllogictest/src/test_context.rs | 18 ++++--- 8 files changed, 88 insertions(+), 56 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index fe88ea6cf115..cc7efd2c7c2e 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -19,10 +19,7 @@ use arrow::compute::kernels::numeric::add; use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::*; -use datafusion::{ - execution::registry::FunctionRegistry, - physical_plan::functions::make_scalar_function, test_util, -}; +use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::as_float64_array; use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue}; use datafusion_expr::{ @@ -87,12 +84,18 @@ async fn scalar_udf() -> Result<()> { ctx.register_batch("t", batch)?; - let myfunc = |args: &[ArrayRef]| { - let l = as_int32_array(&args[0])?; - let r = as_int32_array(&args[1])?; - Ok(Arc::new(add(l, r)?) as ArrayRef) - }; - let myfunc = make_scalar_function(myfunc); + let myfunc = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(l) = &args[0] else { + panic!() + }; + let ColumnarValue::Array(r) = &args[0] else { + panic!() + }; + + let l = as_int32_array(l)?; + let r = as_int32_array(r)?; + Ok(ColumnarValue::Array(Arc::new(add(l, r)?) as ArrayRef)) + }); ctx.register_udf(create_udf( "my_add", @@ -163,11 +166,15 @@ async fn scalar_udf_zero_params() -> Result<()> { ctx.register_batch("t", batch)?; // create function just returns 100 regardless of inp - let myfunc = |args: &[ArrayRef]| { - let num_rows = args[0].len(); - Ok(Arc::new((0..num_rows).map(|_| 100).collect::()) as ArrayRef) - }; - let myfunc = make_scalar_function(myfunc); + let myfunc = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!() + }; + let num_rows = array.len(); + Ok(ColumnarValue::Array(Arc::new( + (0..num_rows).map(|_| 100).collect::(), + ) as ArrayRef)) + }); ctx.register_udf(create_udf( "get_100", @@ -307,8 +314,12 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; ctx.register_batch("t", batch).unwrap(); - let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); - let myfunc = make_scalar_function(myfunc); + let myfunc = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!() + }; + Ok(ColumnarValue::Array(Arc::clone(array))) + }); ctx.register_udf(create_udf( "MY_FUNC", @@ -348,8 +359,12 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; ctx.register_batch("t", batch).unwrap(); - let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); - let myfunc = make_scalar_function(myfunc); + let myfunc = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!() + }; + Ok(ColumnarValue::Array(Arc::clone(array))) + }); let udf = create_udf( "dummy", diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 3ba343003e33..aede86de151f 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1321,9 +1321,7 @@ mod tests { assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema, }; use datafusion_expr::{interval_arithmetic::Interval, *}; - use datafusion_physical_expr::{ - execution_props::ExecutionProps, functions::make_scalar_function, - }; + use datafusion_physical_expr::execution_props::ExecutionProps; use chrono::{DateTime, TimeZone, Utc}; @@ -1438,9 +1436,16 @@ mod tests { let input_types = vec![DataType::Int32, DataType::Int32]; let return_type = Arc::new(DataType::Int32); - let fun = |args: &[ArrayRef]| { - let arg0 = as_int32_array(&args[0])?; - let arg1 = as_int32_array(&args[1])?; + let fun = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(arg0) = &args[0] else { + panic!() + }; + let ColumnarValue::Array(arg1) = &args[1] else { + panic!() + }; + + let arg0 = as_int32_array(arg0)?; + let arg1 = as_int32_array(&arg1)?; // 2. perform the computation let array = arg0 @@ -1456,10 +1461,9 @@ mod tests { }) .collect::(); - Ok(Arc::new(array) as ArrayRef) - }; + Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef)) + }); - let fun = make_scalar_function(fun); Arc::new(create_udf( "udf_add", input_types, diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 66e22d2302de..e811d41886c0 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -191,9 +191,11 @@ pub(crate) enum Hint { AcceptsSingular, } -/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function +/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function /// and vice-versa after evaluation. -pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation +/// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar. +/// That's said its output will be same for all input rows in a batch. +pub(crate) fn make_scalar_function(inner: F) -> ScalarFunctionImplementation where F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 9377501499e2..d9eda5d00d52 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -23,7 +23,6 @@ use crate::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use crate::protobuf; -use datafusion::physical_plan::functions::make_scalar_function; use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::{ create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, @@ -117,7 +116,7 @@ impl Serializeable for Expr { vec![], Arc::new(arrow::datatypes::DataType::Null), Volatility::Immutable, - make_scalar_function(|_| unimplemented!()), + Arc::new(|_| unimplemented!()), ))) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ed21124a9e22..babb80de164b 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -34,7 +34,6 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; -use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; @@ -53,9 +52,9 @@ use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, - Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, - WindowUDFImpl, + ColumnarValue, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, + Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1592,9 +1591,12 @@ fn roundtrip_aggregate_udf() { #[test] fn roundtrip_scalar_udf() { - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); + let scalar_fn = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!() + }; + Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef)) + }); let udf = create_udf( "dummy", diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 3a13dc887f0c..1bda93bc34ae 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -47,7 +47,6 @@ use datafusion::physical_plan::expressions::{ GetFieldAccessExpr, GetIndexedFieldExpr, NotExpr, NthValue, PhysicalSortExpr, Sum, }; use datafusion::physical_plan::filter::FilterExec; -use datafusion::physical_plan::functions::make_scalar_function; use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, @@ -73,8 +72,8 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{FileTypeWriterOptions, Result}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, Signature, SimpleAggregateUDF, - WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, Signature, + SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; @@ -568,9 +567,12 @@ fn roundtrip_scalar_udf() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); + let scalar_fn = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!() + }; + Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef)) + }); let udf = create_udf( "dummy", diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index 5b890accd81f..965179dccefb 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -21,9 +21,8 @@ use arrow::array::ArrayRef; use arrow::datatypes::DataType; use datafusion::execution::FunctionRegistry; -use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::SessionContext; -use datafusion_expr::{col, create_udf, lit}; +use datafusion_expr::{col, create_udf, lit, ColumnarValue}; use datafusion_expr::{Expr, Volatility}; use datafusion_proto::bytes::Serializeable; @@ -226,9 +225,12 @@ fn roundtrip_deeply_nested() { /// return a `SessionContext` with a `dummy` function registered as a UDF fn context_with_udf() -> SessionContext { - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); + let scalar_fn = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(array) = &args[0] else { + panic!() + }; + Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef)) + }); let udf = create_udf( "dummy", diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index a5ce7ccb9fe0..f090769a46c4 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -28,8 +28,7 @@ use arrow::array::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; use datafusion::execution::context::SessionState; -use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility}; -use datafusion::physical_expr::functions::make_scalar_function; +use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use datafusion::{ @@ -356,9 +355,16 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { /// Create a UDF function named "example". See the `sample_udf.rs` example /// file for an explanation of the API. fn create_example_udf() -> ScalarUDF { - let adder = make_scalar_function(|args: &[ArrayRef]| { - let lhs = as_float64_array(&args[0]).expect("cast failed"); - let rhs = as_float64_array(&args[1]).expect("cast failed"); + let adder = Arc::new(|args: &[ColumnarValue]| { + let ColumnarValue::Array(lhs) = &args[0] else { + panic!() + }; + let ColumnarValue::Array(rhs) = &args[1] else { + panic!() + }; + + let lhs = as_float64_array(lhs).expect("cast failed"); + let rhs = as_float64_array(rhs).expect("cast failed"); let array = lhs .iter() .zip(rhs.iter()) @@ -367,7 +373,7 @@ fn create_example_udf() -> ScalarUDF { _ => None, }) .collect::(); - Ok(Arc::new(array) as ArrayRef) + Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef)) }); create_udf( "example", From 7ebd75949a9f29c91b9ffb7639ca6701c5d8fcb2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 15 Jan 2024 13:51:21 -0800 Subject: [PATCH 02/10] More --- datafusion-examples/examples/simple_udf.rs | 23 +++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 39e1e13ce39a..e500633a984e 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -24,9 +24,10 @@ use datafusion::{ logical_expr::Volatility, }; +use datafusion::error::Result; use datafusion::prelude::*; -use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use datafusion_common::cast::as_float64_array; +use datafusion_expr::ColumnarValue; use std::sync::Arc; /// create local execution context with an in-memory table: @@ -61,7 +62,7 @@ async fn main() -> Result<()> { let ctx = create_context()?; // First, declare the actual implementation of the calculation - let pow = |args: &[ArrayRef]| { + let pow = Arc::new(|args: &[ColumnarValue]| { // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: // 1. cast the values to the type we want // 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result @@ -69,9 +70,16 @@ async fn main() -> Result<()> { // this is guaranteed by DataFusion based on the function's signature. assert_eq!(args.len(), 2); + let ColumnarValue::Array(arg0) = &args[0] else { + panic!("should be array") + }; + let ColumnarValue::Array(arg1) = &args[0] else { + panic!("should be array") + }; + // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! - let base = as_float64_array(&args[0]).expect("cast failed"); - let exponent = as_float64_array(&args[1]).expect("cast failed"); + let base = as_float64_array(arg0).expect("cast failed"); + let exponent = as_float64_array(arg1).expect("cast failed"); // this is guaranteed by DataFusion. We place it just to make it obvious. assert_eq!(exponent.len(), base.len()); @@ -92,11 +100,8 @@ async fn main() -> Result<()> { // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) // `Arc` because arrays are immutable, thread-safe, trait objects. - Ok(Arc::new(array) as ArrayRef) - }; - // the function above expects an `ArrayRef`, but DataFusion may pass a scalar to a UDF. - // thus, we use `make_scalar_function` to decorare the closure so that it can handle both Arrays and Scalar values. - let pow = make_scalar_function(pow); + Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef)) + }); // Next: // * give it a name so that it shows nicely when the plan is printed From 88fdc2cd2c94d62760679eceb2751d91b5acaa6c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 15 Jan 2024 14:38:45 -0800 Subject: [PATCH 03/10] More --- .../user_defined/user_defined_scalar_functions.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index cc7efd2c7c2e..ced6b331e7b5 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -167,13 +167,12 @@ async fn scalar_udf_zero_params() -> Result<()> { ctx.register_batch("t", batch)?; // create function just returns 100 regardless of inp let myfunc = Arc::new(|args: &[ColumnarValue]| { - let ColumnarValue::Array(array) = &args[0] else { - panic!() + let ColumnarValue::Scalar(_) = &args[0] else { + panic!("expect scalar") }; - let num_rows = array.len(); - Ok(ColumnarValue::Array(Arc::new( - (0..num_rows).map(|_| 100).collect::(), - ) as ArrayRef)) + Ok(ColumnarValue::Array( + Arc::new((0..1).map(|_| 100).collect::()) as ArrayRef, + )) }); ctx.register_udf(create_udf( From 1336a36b6e47ebfb73add9203344a7c4f401421d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 15 Jan 2024 14:56:02 -0800 Subject: [PATCH 04/10] Fix --- .../core/tests/user_defined/user_defined_scalar_functions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index ced6b331e7b5..7291cc199894 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -88,7 +88,7 @@ async fn scalar_udf() -> Result<()> { let ColumnarValue::Array(l) = &args[0] else { panic!() }; - let ColumnarValue::Array(r) = &args[0] else { + let ColumnarValue::Array(r) = &args[1] else { panic!() }; From 5399f6c50ff688d4dffab2978e1866159f7fce30 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 15 Jan 2024 15:39:12 -0800 Subject: [PATCH 05/10] More --- .../simplify_expressions/expr_simplifier.rs | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index aede86de151f..d96be680766a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1437,14 +1437,29 @@ mod tests { let return_type = Arc::new(DataType::Int32); let fun = Arc::new(|args: &[ColumnarValue]| { - let ColumnarValue::Array(arg0) = &args[0] else { - panic!() + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let inferred_length = len.unwrap_or(1); + + let arg0 = match &args[0] { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.to_array_of_size(inferred_length).unwrap() + } }; - let ColumnarValue::Array(arg1) = &args[1] else { - panic!() + let arg1 = match &args[1] { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.to_array_of_size(inferred_length).unwrap() + } }; - let arg0 = as_int32_array(arg0)?; + let arg0 = as_int32_array(&arg0)?; let arg1 = as_int32_array(&arg1)?; // 2. perform the computation From 25e89c23fa36a93f88c3f2fb50c230b1afb4f229 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 16 Jan 2024 23:27:45 -0800 Subject: [PATCH 06/10] Update datafusion/physical-expr/src/functions.rs Co-authored-by: Andrew Lamb --- datafusion/physical-expr/src/functions.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index e811d41886c0..ecf6c84d5c87 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -195,7 +195,14 @@ pub(crate) enum Hint { /// and vice-versa after evaluation. /// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar. /// That's said its output will be same for all input rows in a batch. -pub(crate) fn make_scalar_function(inner: F) -> ScalarFunctionImplementation +#[deprecated(since = "35.0.0", note = "Implement your function directly in terms of ColumnarValue or use `ScalarUDF` instead")] +pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation + make_scaler_function_inner(inner) +} + +/// Internal implementation, see comments on `make_scalar_function` for caveats +pub(crate) fn make_scalar_function_inner(inner: F) -> ScalarFunctionImplementation + where F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, { From e06ec90ce304a2f7364421696dd4df5b6e4394d4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 16 Jan 2024 23:31:31 -0800 Subject: [PATCH 07/10] For review --- datafusion-examples/examples/simple_udf.rs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index e500633a984e..a9907557ba14 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -70,16 +70,22 @@ async fn main() -> Result<()> { // this is guaranteed by DataFusion based on the function's signature. assert_eq!(args.len(), 2); - let ColumnarValue::Array(arg0) = &args[0] else { - panic!("should be array") - }; - let ColumnarValue::Array(arg1) = &args[0] else { - panic!("should be array") - }; + // Try to obtain row number + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let inferred_length = len.unwrap_or(1); + + let arg0 = args[0].into_array(inferred_length)?; + let arg1 = args[1].into_array(inferred_length)?; // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! - let base = as_float64_array(arg0).expect("cast failed"); - let exponent = as_float64_array(arg1).expect("cast failed"); + let base = as_float64_array(&arg0).expect("cast failed"); + let exponent = as_float64_array(&arg1).expect("cast failed"); // this is guaranteed by DataFusion. We place it just to make it obvious. assert_eq!(exponent.len(), base.len()); From 13578b4273eea549ebc26ba0fa8e0b4ccb5eaf01 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 17 Jan 2024 00:03:00 -0800 Subject: [PATCH 08/10] For review --- datafusion-examples/examples/simple_udf.rs | 2 +- .../user_defined_scalar_functions.rs | 14 +- datafusion/expr/src/columnar_value.rs | 12 + .../simplify_expressions/expr_simplifier.rs | 2 +- datafusion/physical-expr/src/functions.rs | 345 +++++++++--------- .../physical-expr/src/regex_expressions.rs | 6 +- .../tests/cases/roundtrip_logical_plan.rs | 4 +- .../tests/cases/roundtrip_physical_plan.rs | 4 +- datafusion/proto/tests/cases/serialize.rs | 4 +- datafusion/sqllogictest/src/test_context.rs | 6 +- 10 files changed, 209 insertions(+), 190 deletions(-) diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index a9907557ba14..88cdf59c702f 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -106,7 +106,7 @@ async fn main() -> Result<()> { // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) // `Arc` because arrays are immutable, thread-safe, trait objects. - Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef)) + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) }); // Next: diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 7291cc199894..b8573a690e7b 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -86,15 +86,15 @@ async fn scalar_udf() -> Result<()> { let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(l) = &args[0] else { - panic!() + panic!("should be array") }; let ColumnarValue::Array(r) = &args[1] else { - panic!() + panic!("should be array") }; let l = as_int32_array(l)?; let r = as_int32_array(r)?; - Ok(ColumnarValue::Array(Arc::new(add(l, r)?) as ArrayRef)) + Ok(ColumnarValue::from(Arc::new(add(l, r)?) as ArrayRef)) }); ctx.register_udf(create_udf( @@ -315,9 +315,9 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { - panic!() + panic!("should be array") }; - Ok(ColumnarValue::Array(Arc::clone(array))) + Ok(ColumnarValue::from(Arc::clone(array))) }); ctx.register_udf(create_udf( @@ -360,9 +360,9 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { - panic!() + panic!("should be array") }; - Ok(ColumnarValue::Array(Arc::clone(array))) + Ok(ColumnarValue::from(Arc::clone(array))) }); let udf = create_udf( diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index 7a2883928169..58c534b50aad 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -37,6 +37,18 @@ pub enum ColumnarValue { Scalar(ScalarValue), } +impl From for ColumnarValue { + fn from(value: ArrayRef) -> Self { + ColumnarValue::Array(value) + } +} + +impl From for ColumnarValue { + fn from(value: ScalarValue) -> Self { + ColumnarValue::Scalar(value) + } +} + impl ColumnarValue { pub fn data_type(&self) -> DataType { match self { diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index d96be680766a..674e85a55c92 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1476,7 +1476,7 @@ mod tests { }) .collect::(); - Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef)) + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) }); Arc::new(create_udf( diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index ecf6c84d5c87..3fefd40fd382 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -195,14 +195,19 @@ pub(crate) enum Hint { /// and vice-versa after evaluation. /// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar. /// That's said its output will be same for all input rows in a batch. -#[deprecated(since = "35.0.0", note = "Implement your function directly in terms of ColumnarValue or use `ScalarUDF` instead")] +#[deprecated( + since = "35.0.0", + note = "Implement your function directly in terms of ColumnarValue or use `ScalarUDF` instead" +)] pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation - make_scaler_function_inner(inner) +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + make_scalar_function_inner(inner) } /// Internal implementation, see comments on `make_scalar_function` for caveats -pub(crate) fn make_scalar_function_inner(inner: F) -> ScalarFunctionImplementation - +pub(crate) fn make_scalar_function_inner(inner: F) -> ScalarFunctionImplementation where F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, { @@ -269,9 +274,9 @@ pub fn create_physical_fun( ) -> Result { Ok(match fun { // math functions - BuiltinScalarFunction::Abs => { - Arc::new(|args| make_scalar_function(math_expressions::abs_invoke)(args)) - } + BuiltinScalarFunction::Abs => Arc::new(|args| { + make_scalar_function_inner(math_expressions::abs_invoke)(args) + }), BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), @@ -284,31 +289,31 @@ pub fn create_physical_fun( BuiltinScalarFunction::Degrees => Arc::new(math_expressions::to_degrees), BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp), BuiltinScalarFunction::Factorial => { - Arc::new(|args| make_scalar_function(math_expressions::factorial)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::factorial)(args)) } BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor), BuiltinScalarFunction::Gcd => { - Arc::new(|args| make_scalar_function(math_expressions::gcd)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::gcd)(args)) } BuiltinScalarFunction::Isnan => { - Arc::new(|args| make_scalar_function(math_expressions::isnan)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::isnan)(args)) } BuiltinScalarFunction::Iszero => { - Arc::new(|args| make_scalar_function(math_expressions::iszero)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::iszero)(args)) } BuiltinScalarFunction::Lcm => { - Arc::new(|args| make_scalar_function(math_expressions::lcm)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::lcm)(args)) } BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), BuiltinScalarFunction::Nanvl => { - Arc::new(|args| make_scalar_function(math_expressions::nanvl)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::nanvl)(args)) } BuiltinScalarFunction::Radians => Arc::new(math_expressions::to_radians), BuiltinScalarFunction::Random => Arc::new(math_expressions::random), BuiltinScalarFunction::Round => { - Arc::new(|args| make_scalar_function(math_expressions::round)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::round)(args)) } BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum), BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin), @@ -318,135 +323,135 @@ pub fn create_physical_fun( BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh), BuiltinScalarFunction::Trunc => { - Arc::new(|args| make_scalar_function(math_expressions::trunc)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) } BuiltinScalarFunction::Pi => Arc::new(math_expressions::pi), BuiltinScalarFunction::Power => { - Arc::new(|args| make_scalar_function(math_expressions::power)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args)) } BuiltinScalarFunction::Atan2 => { - Arc::new(|args| make_scalar_function(math_expressions::atan2)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::atan2)(args)) } BuiltinScalarFunction::Log => { - Arc::new(|args| make_scalar_function(math_expressions::log)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::log)(args)) } BuiltinScalarFunction::Cot => { - Arc::new(|args| make_scalar_function(math_expressions::cot)(args)) + Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } // array functions - BuiltinScalarFunction::ArrayAppend => { - Arc::new(|args| make_scalar_function(array_expressions::array_append)(args)) - } - BuiltinScalarFunction::ArraySort => { - Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args)) - } - BuiltinScalarFunction::ArrayConcat => { - Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) - } - BuiltinScalarFunction::ArrayEmpty => { - Arc::new(|args| make_scalar_function(array_expressions::array_empty)(args)) - } - BuiltinScalarFunction::ArrayHasAll => { - Arc::new(|args| make_scalar_function(array_expressions::array_has_all)(args)) - } - BuiltinScalarFunction::ArrayHasAny => { - Arc::new(|args| make_scalar_function(array_expressions::array_has_any)(args)) - } - BuiltinScalarFunction::ArrayHas => { - Arc::new(|args| make_scalar_function(array_expressions::array_has)(args)) - } - BuiltinScalarFunction::ArrayDims => { - Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) - } - BuiltinScalarFunction::ArrayDistinct => { - Arc::new(|args| make_scalar_function(array_expressions::array_distinct)(args)) - } - BuiltinScalarFunction::ArrayElement => { - Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) - } - BuiltinScalarFunction::ArrayExcept => { - Arc::new(|args| make_scalar_function(array_expressions::array_except)(args)) - } - BuiltinScalarFunction::ArrayLength => { - Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) - } + BuiltinScalarFunction::ArrayAppend => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_append)(args) + }), + BuiltinScalarFunction::ArraySort => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_sort)(args) + }), + BuiltinScalarFunction::ArrayConcat => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_concat)(args) + }), + BuiltinScalarFunction::ArrayEmpty => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_empty)(args) + }), + BuiltinScalarFunction::ArrayHasAll => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_has_all)(args) + }), + BuiltinScalarFunction::ArrayHasAny => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_has_any)(args) + }), + BuiltinScalarFunction::ArrayHas => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_has)(args) + }), + BuiltinScalarFunction::ArrayDims => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_dims)(args) + }), + BuiltinScalarFunction::ArrayDistinct => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_distinct)(args) + }), + BuiltinScalarFunction::ArrayElement => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_element)(args) + }), + BuiltinScalarFunction::ArrayExcept => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_except)(args) + }), + BuiltinScalarFunction::ArrayLength => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_length)(args) + }), BuiltinScalarFunction::Flatten => { - Arc::new(|args| make_scalar_function(array_expressions::flatten)(args)) - } - BuiltinScalarFunction::ArrayNdims => { - Arc::new(|args| make_scalar_function(array_expressions::array_ndims)(args)) + Arc::new(|args| make_scalar_function_inner(array_expressions::flatten)(args)) } + BuiltinScalarFunction::ArrayNdims => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_ndims)(args) + }), BuiltinScalarFunction::ArrayPopFront => Arc::new(|args| { - make_scalar_function(array_expressions::array_pop_front)(args) + make_scalar_function_inner(array_expressions::array_pop_front)(args) + }), + BuiltinScalarFunction::ArrayPopBack => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_pop_back)(args) + }), + BuiltinScalarFunction::ArrayPosition => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_position)(args) }), - BuiltinScalarFunction::ArrayPopBack => { - Arc::new(|args| make_scalar_function(array_expressions::array_pop_back)(args)) - } - BuiltinScalarFunction::ArrayPosition => { - Arc::new(|args| make_scalar_function(array_expressions::array_position)(args)) - } BuiltinScalarFunction::ArrayPositions => Arc::new(|args| { - make_scalar_function(array_expressions::array_positions)(args) + make_scalar_function_inner(array_expressions::array_positions)(args) + }), + BuiltinScalarFunction::ArrayPrepend => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_prepend)(args) + }), + BuiltinScalarFunction::ArrayRepeat => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_repeat)(args) + }), + BuiltinScalarFunction::ArrayRemove => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_remove)(args) + }), + BuiltinScalarFunction::ArrayRemoveN => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_remove_n)(args) }), - BuiltinScalarFunction::ArrayPrepend => { - Arc::new(|args| make_scalar_function(array_expressions::array_prepend)(args)) - } - BuiltinScalarFunction::ArrayRepeat => { - Arc::new(|args| make_scalar_function(array_expressions::array_repeat)(args)) - } - BuiltinScalarFunction::ArrayRemove => { - Arc::new(|args| make_scalar_function(array_expressions::array_remove)(args)) - } - BuiltinScalarFunction::ArrayRemoveN => { - Arc::new(|args| make_scalar_function(array_expressions::array_remove_n)(args)) - } BuiltinScalarFunction::ArrayRemoveAll => Arc::new(|args| { - make_scalar_function(array_expressions::array_remove_all)(args) + make_scalar_function_inner(array_expressions::array_remove_all)(args) + }), + BuiltinScalarFunction::ArrayReplace => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_replace)(args) }), - BuiltinScalarFunction::ArrayReplace => { - Arc::new(|args| make_scalar_function(array_expressions::array_replace)(args)) - } BuiltinScalarFunction::ArrayReplaceN => Arc::new(|args| { - make_scalar_function(array_expressions::array_replace_n)(args) + make_scalar_function_inner(array_expressions::array_replace_n)(args) }), BuiltinScalarFunction::ArrayReplaceAll => Arc::new(|args| { - make_scalar_function(array_expressions::array_replace_all)(args) + make_scalar_function_inner(array_expressions::array_replace_all)(args) + }), + BuiltinScalarFunction::ArraySlice => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_slice)(args) }), - BuiltinScalarFunction::ArraySlice => { - Arc::new(|args| make_scalar_function(array_expressions::array_slice)(args)) - } BuiltinScalarFunction::ArrayToString => Arc::new(|args| { - make_scalar_function(array_expressions::array_to_string)(args) + make_scalar_function_inner(array_expressions::array_to_string)(args) }), BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { - make_scalar_function(array_expressions::array_intersect)(args) + make_scalar_function_inner(array_expressions::array_intersect)(args) + }), + BuiltinScalarFunction::Range => Arc::new(|args| { + make_scalar_function_inner(array_expressions::gen_range)(args) + }), + BuiltinScalarFunction::Cardinality => Arc::new(|args| { + make_scalar_function_inner(array_expressions::cardinality)(args) + }), + BuiltinScalarFunction::ArrayResize => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_resize)(args) + }), + BuiltinScalarFunction::MakeArray => Arc::new(|args| { + make_scalar_function_inner(array_expressions::make_array)(args) + }), + BuiltinScalarFunction::ArrayUnion => Arc::new(|args| { + make_scalar_function_inner(array_expressions::array_union)(args) }), - BuiltinScalarFunction::Range => { - Arc::new(|args| make_scalar_function(array_expressions::gen_range)(args)) - } - BuiltinScalarFunction::Cardinality => { - Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) - } - BuiltinScalarFunction::ArrayResize => { - Arc::new(|args| make_scalar_function(array_expressions::array_resize)(args)) - } - BuiltinScalarFunction::MakeArray => { - Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) - } - BuiltinScalarFunction::ArrayUnion => { - Arc::new(|args| make_scalar_function(array_expressions::array_union)(args)) - } // struct functions BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), // string functions BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::ascii::)(args) + make_scalar_function_inner(string_expressions::ascii::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::ascii::)(args) + make_scalar_function_inner(string_expressions::ascii::)(args) } other => internal_err!("Unsupported data type {other:?} for function ascii"), }), @@ -464,10 +469,10 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Btrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function_inner(string_expressions::btrim::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function_inner(string_expressions::btrim::)(args) } other => internal_err!("Unsupported data type {other:?} for function btrim"), }), @@ -479,7 +484,7 @@ pub fn create_physical_fun( Int32Type, "character_length" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -487,7 +492,7 @@ pub fn create_physical_fun( Int64Type, "character_length" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!( "Unsupported data type {other:?} for function character_length" @@ -495,13 +500,13 @@ pub fn create_physical_fun( }) } BuiltinScalarFunction::Chr => { - Arc::new(|args| make_scalar_function(string_expressions::chr)(args)) + Arc::new(|args| make_scalar_function_inner(string_expressions::chr)(args)) } BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), - BuiltinScalarFunction::ConcatWithSeparator => { - Arc::new(|args| make_scalar_function(string_expressions::concat_ws)(args)) - } + BuiltinScalarFunction::ConcatWithSeparator => Arc::new(|args| { + make_scalar_function_inner(string_expressions::concat_ws)(args) + }), BuiltinScalarFunction::DatePart => Arc::new(datetime_expressions::date_part), BuiltinScalarFunction::DateTrunc => Arc::new(datetime_expressions::date_trunc), BuiltinScalarFunction::DateBin => Arc::new(datetime_expressions::date_bin), @@ -543,10 +548,10 @@ pub fn create_physical_fun( } BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::initcap::)(args) + make_scalar_function_inner(string_expressions::initcap::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::initcap::)(args) + make_scalar_function_inner(string_expressions::initcap::)(args) } other => { internal_err!("Unsupported data type {other:?} for function initcap") @@ -555,11 +560,11 @@ pub fn create_physical_fun( BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function left"), }), @@ -567,20 +572,20 @@ pub fn create_physical_fun( BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function lpad"), }), BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::ltrim::)(args) + make_scalar_function_inner(string_expressions::ltrim::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::ltrim::)(args) + make_scalar_function_inner(string_expressions::ltrim::)(args) } other => internal_err!("Unsupported data type {other:?} for function ltrim"), }), @@ -617,7 +622,7 @@ pub fn create_physical_fun( i32, "regexp_match" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_on_array_if_regex_expressions_feature_flag!( @@ -625,7 +630,7 @@ pub fn create_physical_fun( i64, "regexp_match" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!( "Unsupported data type {other:?} for function regexp_match" @@ -659,19 +664,19 @@ pub fn create_physical_fun( } BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::repeat::)(args) + make_scalar_function_inner(string_expressions::repeat::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::repeat::)(args) + make_scalar_function_inner(string_expressions::repeat::)(args) } other => internal_err!("Unsupported data type {other:?} for function repeat"), }), BuiltinScalarFunction::Replace => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::replace::)(args) + make_scalar_function_inner(string_expressions::replace::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::replace::)(args) + make_scalar_function_inner(string_expressions::replace::)(args) } other => { internal_err!("Unsupported data type {other:?} for function replace") @@ -681,12 +686,12 @@ pub fn create_physical_fun( DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => { internal_err!("Unsupported data type {other:?} for function reverse") @@ -696,32 +701,32 @@ pub fn create_physical_fun( DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function right"), }), BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function rpad"), }), BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::rtrim::)(args) + make_scalar_function_inner(string_expressions::rtrim::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::rtrim::)(args) + make_scalar_function_inner(string_expressions::rtrim::)(args) } other => internal_err!("Unsupported data type {other:?} for function rtrim"), }), @@ -739,10 +744,10 @@ pub fn create_physical_fun( } BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::split_part::)(args) + make_scalar_function_inner(string_expressions::split_part::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::split_part::)(args) + make_scalar_function_inner(string_expressions::split_part::)(args) } other => { internal_err!("Unsupported data type {other:?} for function split_part") @@ -750,12 +755,12 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::StringToArray => { Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(array_expressions::string_to_array::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(array_expressions::string_to_array::)(args) - } + DataType::Utf8 => make_scalar_function_inner( + array_expressions::string_to_array::, + )(args), + DataType::LargeUtf8 => make_scalar_function_inner( + array_expressions::string_to_array::, + )(args), other => { internal_err!( "Unsupported data type {other:?} for function string_to_array" @@ -765,10 +770,10 @@ pub fn create_physical_fun( } BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::starts_with::)(args) + make_scalar_function_inner(string_expressions::starts_with::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::starts_with::)(args) + make_scalar_function_inner(string_expressions::starts_with::)(args) } other => { internal_err!("Unsupported data type {other:?} for function starts_with") @@ -779,13 +784,13 @@ pub fn create_physical_fun( let func = invoke_if_unicode_expressions_feature_flag!( strpos, Int32Type, "strpos" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( strpos, Int64Type, "strpos" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function strpos"), }), @@ -793,21 +798,21 @@ pub fn create_physical_fun( DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => internal_err!("Unsupported data type {other:?} for function substr"), }), BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { DataType::Int32 => { - make_scalar_function(string_expressions::to_hex::)(args) + make_scalar_function_inner(string_expressions::to_hex::)(args) } DataType::Int64 => { - make_scalar_function(string_expressions::to_hex::)(args) + make_scalar_function_inner(string_expressions::to_hex::)(args) } other => internal_err!("Unsupported data type {other:?} for function to_hex"), }), @@ -818,7 +823,7 @@ pub fn create_physical_fun( i32, "translate" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -826,7 +831,7 @@ pub fn create_physical_fun( i64, "translate" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => { internal_err!("Unsupported data type {other:?} for function translate") @@ -834,10 +839,10 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function_inner(string_expressions::btrim::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function_inner(string_expressions::btrim::)(args) } other => internal_err!("Unsupported data type {other:?} for function trim"), }), @@ -858,10 +863,10 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::overlay::)(args) + make_scalar_function_inner(string_expressions::overlay::)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::overlay::)(args) + make_scalar_function_inner(string_expressions::overlay::)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function overlay", @@ -869,12 +874,12 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Levenshtein => { Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::levenshtein::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::levenshtein::)(args) - } + DataType::Utf8 => make_scalar_function_inner( + string_expressions::levenshtein::, + )(args), + DataType::LargeUtf8 => make_scalar_function_inner( + string_expressions::levenshtein::, + )(args), other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function levenshtein", ))), @@ -888,7 +893,7 @@ pub fn create_physical_fun( i32, "substr_index" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -896,7 +901,7 @@ pub fn create_physical_fun( i64, "substr_index" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function substr_index", @@ -910,7 +915,7 @@ pub fn create_physical_fun( Int32Type, "find_in_set" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -918,7 +923,7 @@ pub fn create_physical_fun( Int64Type, "find_in_set" ); - make_scalar_function(func)(args) + make_scalar_function_inner(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function find_in_set", @@ -3117,7 +3122,7 @@ mod tests { #[test] fn test_make_scalar_function() -> Result<()> { - let adapter_func = make_scalar_function(dummy_function); + let adapter_func = make_scalar_function_inner(dummy_function); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); let array_arg = ColumnarValue::Array( diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index b778fd86c24b..bdd272563e75 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -36,7 +36,9 @@ use hashbrown::HashMap; use regex::Regex; use std::sync::{Arc, OnceLock}; -use crate::functions::{make_scalar_function, make_scalar_function_with_hints, Hint}; +use crate::functions::{ + make_scalar_function_inner, make_scalar_function_with_hints, Hint, +}; /// Get the first argument from the given string array. /// @@ -401,7 +403,7 @@ pub fn specialize_regexp_replace( // If there are no specialized implementations, we'll fall back to the // generic implementation. - (_, _, _, _) => Ok(make_scalar_function(regexp_replace::)), + (_, _, _, _) => Ok(make_scalar_function_inner(regexp_replace::)), } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index babb80de164b..0e78aa3fc98e 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1593,9 +1593,9 @@ fn roundtrip_aggregate_udf() { fn roundtrip_scalar_udf() { let scalar_fn = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { - panic!() + panic!("should be array") }; - Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef)) + Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef)) }); let udf = create_udf( diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 1bda93bc34ae..1bc10438b03e 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -569,9 +569,9 @@ fn roundtrip_scalar_udf() -> Result<()> { let scalar_fn = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { - panic!() + panic!("should be array") }; - Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef)) + Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef)) }); let udf = create_udf( diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index 965179dccefb..d0359747b4e0 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -227,9 +227,9 @@ fn roundtrip_deeply_nested() { fn context_with_udf() -> SessionContext { let scalar_fn = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { - panic!() + panic!("should be array") }; - Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef)) + Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef)) }); let udf = create_udf( diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index f090769a46c4..889ccdcd66d4 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -357,10 +357,10 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { fn create_example_udf() -> ScalarUDF { let adder = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(lhs) = &args[0] else { - panic!() + panic!("should be array") }; let ColumnarValue::Array(rhs) = &args[1] else { - panic!() + panic!("should be array") }; let lhs = as_float64_array(lhs).expect("cast failed"); @@ -373,7 +373,7 @@ fn create_example_udf() -> ScalarUDF { _ => None, }) .collect::(); - Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef)) + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) }); create_udf( "example", From d7b117ab8d5a66ddaadf0481a68b3a2f2ad17d4b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 17 Jan 2024 00:11:25 -0800 Subject: [PATCH 09/10] Fix --- datafusion-examples/examples/simple_udf.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 88cdf59c702f..491fac272c2c 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -80,8 +80,8 @@ async fn main() -> Result<()> { let inferred_length = len.unwrap_or(1); - let arg0 = args[0].into_array(inferred_length)?; - let arg1 = args[1].into_array(inferred_length)?; + let arg0 = args[0].clone().into_array(inferred_length)?; + let arg1 = args[1].clone().into_array(inferred_length)?; // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! let base = as_float64_array(&arg0).expect("cast failed"); From 64a7642ba987f293eafe9f14a0084a3ea04b9e5c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Jan 2024 16:15:58 -0800 Subject: [PATCH 10/10] Update deprecated since tag --- datafusion/physical-expr/src/functions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 3fefd40fd382..d1e75bfe4f56 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -196,7 +196,7 @@ pub(crate) enum Hint { /// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar. /// That's said its output will be same for all input rows in a batch. #[deprecated( - since = "35.0.0", + since = "36.0.0", note = "Implement your function directly in terms of ColumnarValue or use `ScalarUDF` instead" )] pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation