Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate make_scalar_function #8878

Merged
merged 11 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{
execution::registry::FunctionRegistry,
physical_plan::functions::make_scalar_function, test_util,
};
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
use datafusion_expr::{
Expand Down Expand Up @@ -87,12 +84,18 @@ async fn scalar_udf() -> Result<()> {

ctx.register_batch("t", batch)?;

let myfunc = |args: &[ArrayRef]| {
let l = as_int32_array(&args[0])?;
let r = as_int32_array(&args[1])?;
Ok(Arc::new(add(l, r)?) as ArrayRef)
};
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(l) = &args[0] else {
panic!()
};
let ColumnarValue::Array(r) = &args[1] else {
panic!()
};

let l = as_int32_array(l)?;
let r = as_int32_array(r)?;
Ok(ColumnarValue::Array(Arc::new(add(l, r)?) as ArrayRef))
});

ctx.register_udf(create_udf(
"my_add",
Expand Down Expand Up @@ -163,11 +166,15 @@ async fn scalar_udf_zero_params() -> Result<()> {

ctx.register_batch("t", batch)?;
// create function just returns 100 regardless of input
let myfunc = |args: &[ArrayRef]| {
let num_rows = args[0].len();
Ok(Arc::new((0..num_rows).map(|_| 100).collect::<Int32Array>()) as ArrayRef)
};
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
let num_rows = array.len();
Ok(ColumnarValue::Array(Arc::new(
(0..num_rows).map(|_| 100).collect::<Int32Array>(),
) as ArrayRef))
});

ctx.register_udf(create_udf(
"get_100",
Expand Down Expand Up @@ -307,8 +314,12 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::clone(array)))
});

ctx.register_udf(create_udf(
"MY_FUNC",
Expand Down Expand Up @@ -348,8 +359,12 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::clone(array)))
});

let udf = create_udf(
"dummy",
Expand Down
22 changes: 13 additions & 9 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1321,9 +1321,7 @@ mod tests {
assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema,
};
use datafusion_expr::{interval_arithmetic::Interval, *};
use datafusion_physical_expr::{
execution_props::ExecutionProps, functions::make_scalar_function,
};
use datafusion_physical_expr::execution_props::ExecutionProps;

use chrono::{DateTime, TimeZone, Utc};

Expand Down Expand Up @@ -1438,9 +1436,16 @@ mod tests {
let input_types = vec![DataType::Int32, DataType::Int32];
let return_type = Arc::new(DataType::Int32);

let fun = |args: &[ArrayRef]| {
let arg0 = as_int32_array(&args[0])?;
let arg1 = as_int32_array(&args[1])?;
let fun = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(arg0) = &args[0] else {
panic!()
};
let ColumnarValue::Array(arg1) = &args[1] else {
panic!()
};

let arg0 = as_int32_array(arg0)?;
let arg1 = as_int32_array(&arg1)?;

// 2. perform the computation
let array = arg0
Expand All @@ -1456,10 +1461,9 @@ mod tests {
})
.collect::<Int32Array>();

Ok(Arc::new(array) as ArrayRef)
};
Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef))
});

let fun = make_scalar_function(fun);
Arc::new(create_udf(
"udf_add",
input_types,
Expand Down
6 changes: 4 additions & 2 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ pub(crate) enum Hint {
AcceptsSingular,
}

/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
/// and vice-versa after evaluation.
pub fn make_scalar_function<F>(inner: F) -> ScalarFunctionImplementation
/// Note that this function makes a scalar function with no arguments, or with all scalar inputs, return a scalar.
/// That said, its output will be the same for all input rows in a batch.
pub(crate) fn make_scalar_function<F>(inner: F) -> ScalarFunctionImplementation
viirya marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member Author

@viirya viirya Jan 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or we should deprecate it and ask user to use ScalarUDFImpl instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a behavior change, so I add a label for it

where
F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static,
{
Expand Down
3 changes: 1 addition & 2 deletions datafusion/proto/src/bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use crate::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
};
use crate::protobuf;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::{
create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility,
Expand Down Expand Up @@ -117,7 +116,7 @@ impl Serializeable for Expr {
vec![],
Arc::new(arrow::datatypes::DataType::Null),
Volatility::Immutable,
make_scalar_function(|_| unimplemented!()),
Arc::new(|_| unimplemented!()),
)))
}

Expand Down
16 changes: 9 additions & 7 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::parquet::file::properties::{WriterProperties, WriterVersion};
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext};
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
Expand All @@ -53,9 +52,9 @@ use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
use datafusion_expr::{
col, create_udaf, lit, Accumulator, AggregateFunction,
BuiltinScalarFunction::{Sqrt, Substr},
Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF,
WindowUDFImpl,
ColumnarValue, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast,
Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition, WindowUDF, WindowUDFImpl,
};
use datafusion_proto::bytes::{
logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
Expand Down Expand Up @@ -1592,9 +1591,12 @@ fn roundtrip_aggregate_udf() {

#[test]
fn roundtrip_scalar_udf() {
let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);

let scalar_fn = make_scalar_function(fn_impl);
let scalar_fn = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we could make this code easier to work with using From impls, like

Suggested change
Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef))
Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef))

Maybe it doesn't matter 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I added two From impls for ColumnarValue.

});

let udf = create_udf(
"dummy",
Expand Down
14 changes: 8 additions & 6 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ use datafusion::physical_plan::expressions::{
GetFieldAccessExpr, GetIndexedFieldExpr, NotExpr, NthValue, PhysicalSortExpr, Sum,
};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::physical_plan::insert::FileSinkExec;
use datafusion::physical_plan::joins::{
HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode,
Expand All @@ -73,8 +72,8 @@ use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::stats::Precision;
use datafusion_common::{FileTypeWriterOptions, Result};
use datafusion_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, Signature, SimpleAggregateUDF,
WindowFrame, WindowFrameBound,
Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, Signature,
SimpleAggregateUDF, WindowFrame, WindowFrameBound,
};
use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
use datafusion_proto::protobuf;
Expand Down Expand Up @@ -568,9 +567,12 @@ fn roundtrip_scalar_udf() -> Result<()> {

let input = Arc::new(EmptyExec::new(schema.clone()));

let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);

let scalar_fn = make_scalar_function(fn_impl);
let scalar_fn = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef))
});

let udf = create_udf(
"dummy",
Expand Down
12 changes: 7 additions & 5 deletions datafusion/proto/tests/cases/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ use arrow::array::ArrayRef;
use arrow::datatypes::DataType;

use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::SessionContext;
use datafusion_expr::{col, create_udf, lit};
use datafusion_expr::{col, create_udf, lit, ColumnarValue};
use datafusion_expr::{Expr, Volatility};
use datafusion_proto::bytes::Serializeable;

Expand Down Expand Up @@ -226,9 +225,12 @@ fn roundtrip_deeply_nested() {

/// return a `SessionContext` with a `dummy` function registered as a UDF
fn context_with_udf() -> SessionContext {
let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);

let scalar_fn = make_scalar_function(fn_impl);
let scalar_fn = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef))
});

let udf = create_udf(
"dummy",
Expand Down
18 changes: 12 additions & 6 deletions datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ use arrow::array::{
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow::record_batch::RecordBatch;
use datafusion::execution::context::SessionState;
use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility};
use datafusion::physical_expr::functions::make_scalar_function;
use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use datafusion::{
Expand Down Expand Up @@ -356,9 +355,16 @@ pub async fn register_metadata_tables(ctx: &SessionContext) {
/// Create a UDF function named "example". See the `sample_udf.rs` example
/// file for an explanation of the API.
fn create_example_udf() -> ScalarUDF {
let adder = make_scalar_function(|args: &[ArrayRef]| {
let lhs = as_float64_array(&args[0]).expect("cast failed");
let rhs = as_float64_array(&args[1]).expect("cast failed");
let adder = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(lhs) = &args[0] else {
panic!()
};
let ColumnarValue::Array(rhs) = &args[1] else {
panic!()
};

let lhs = as_float64_array(lhs).expect("cast failed");
let rhs = as_float64_array(rhs).expect("cast failed");
let array = lhs
.iter()
.zip(rhs.iter())
Expand All @@ -367,7 +373,7 @@ fn create_example_udf() -> ScalarUDF {
_ => None,
})
.collect::<Float64Array>();
Ok(Arc::new(array) as ArrayRef)
Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef))
});
create_udf(
"example",
Expand Down
Loading