Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate make_scalar_function #8878

Merged
merged 11 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions datafusion-examples/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ use datafusion::{
logical_expr::Volatility,
};

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion::{error::Result, physical_plan::functions::make_scalar_function};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::ColumnarValue;
use std::sync::Arc;

/// create local execution context with an in-memory table:
Expand Down Expand Up @@ -61,17 +62,30 @@ async fn main() -> Result<()> {
let ctx = create_context()?;

// First, declare the actual implementation of the calculation
let pow = |args: &[ArrayRef]| {
let pow = Arc::new(|args: &[ColumnarValue]| {
// in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to:
// 1. cast the values to the type we want
// 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result

// this is guaranteed by DataFusion based on the function's signature.
assert_eq!(args.len(), 2);

// Try to obtain row number
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let inferred_length = len.unwrap_or(1);

let arg0 = args[0].clone().into_array(inferred_length)?;
let arg1 = args[1].clone().into_array(inferred_length)?;

// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
let base = as_float64_array(&args[0]).expect("cast failed");
let exponent = as_float64_array(&args[1]).expect("cast failed");
let base = as_float64_array(&arg0).expect("cast failed");
let exponent = as_float64_array(&arg1).expect("cast failed");

// this is guaranteed by DataFusion. We place it just to make it obvious.
assert_eq!(exponent.len(), base.len());
Expand All @@ -92,11 +106,8 @@ async fn main() -> Result<()> {

// `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!)
// `Arc` because arrays are immutable, thread-safe, trait objects.
Ok(Arc::new(array) as ArrayRef)
};
// the function above expects an `ArrayRef`, but DataFusion may pass a scalar to a UDF.
// thus, we use `make_scalar_function` to decorate the closure so that it can handle both Arrays and Scalar values.
let pow = make_scalar_function(pow);
Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
});

// Next:
// * give it a name so that it shows nicely when the plan is printed
Expand Down
52 changes: 33 additions & 19 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{
execution::registry::FunctionRegistry,
physical_plan::functions::make_scalar_function, test_util,
};
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
use datafusion_expr::{
Expand Down Expand Up @@ -87,12 +84,18 @@ async fn scalar_udf() -> Result<()> {

ctx.register_batch("t", batch)?;

let myfunc = |args: &[ArrayRef]| {
let l = as_int32_array(&args[0])?;
let r = as_int32_array(&args[1])?;
Ok(Arc::new(add(l, r)?) as ArrayRef)
};
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(l) = &args[0] else {
panic!("should be array")
};
let ColumnarValue::Array(r) = &args[1] else {
panic!("should be array")
};

let l = as_int32_array(l)?;
let r = as_int32_array(r)?;
Ok(ColumnarValue::from(Arc::new(add(l, r)?) as ArrayRef))
});

ctx.register_udf(create_udf(
"my_add",
Expand Down Expand Up @@ -163,11 +166,14 @@ async fn scalar_udf_zero_params() -> Result<()> {

ctx.register_batch("t", batch)?;
// create a function that just returns 100 regardless of input
let myfunc = |args: &[ArrayRef]| {
let num_rows = args[0].len();
Ok(Arc::new((0..num_rows).map(|_| 100).collect::<Int32Array>()) as ArrayRef)
};
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Scalar(_) = &args[0] else {
panic!("expect scalar")
};
Ok(ColumnarValue::Array(
Arc::new((0..1).map(|_| 100).collect::<Int32Array>()) as ArrayRef,
))
});

ctx.register_udf(create_udf(
"get_100",
Expand Down Expand Up @@ -307,8 +313,12 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!("should be array")
};
Ok(ColumnarValue::from(Arc::clone(array)))
});

ctx.register_udf(create_udf(
"MY_FUNC",
Expand Down Expand Up @@ -348,8 +358,12 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!("should be array")
};
Ok(ColumnarValue::from(Arc::clone(array)))
});

let udf = create_udf(
"dummy",
Expand Down
12 changes: 12 additions & 0 deletions datafusion/expr/src/columnar_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ pub enum ColumnarValue {
Scalar(ScalarValue),
}

impl From<ArrayRef> for ColumnarValue {
fn from(value: ArrayRef) -> Self {
ColumnarValue::Array(value)
}
}

impl From<ScalarValue> for ColumnarValue {
fn from(value: ScalarValue) -> Self {
ColumnarValue::Scalar(value)
}
}

impl ColumnarValue {
pub fn data_type(&self) -> DataType {
match self {
Expand Down
37 changes: 28 additions & 9 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1321,9 +1321,7 @@ mod tests {
assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema,
};
use datafusion_expr::{interval_arithmetic::Interval, *};
use datafusion_physical_expr::{
execution_props::ExecutionProps, functions::make_scalar_function,
};
use datafusion_physical_expr::execution_props::ExecutionProps;

use chrono::{DateTime, TimeZone, Utc};

Expand Down Expand Up @@ -1438,9 +1436,31 @@ mod tests {
let input_types = vec![DataType::Int32, DataType::Int32];
let return_type = Arc::new(DataType::Int32);

let fun = |args: &[ArrayRef]| {
let arg0 = as_int32_array(&args[0])?;
let arg1 = as_int32_array(&args[1])?;
let fun = Arc::new(|args: &[ColumnarValue]| {
let len = args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe as a follow-on PR we can make a function that does this length inference and conversion to array (mostly so it can be documented), to make the ideas easier to find 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I had the same thought. It will still be useful for `ScalarUDFImpl`, as `invoke` takes `args: &[ColumnarValue]` and users might need length inference too.

.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let inferred_length = len.unwrap_or(1);

let arg0 = match &args[0] {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(inferred_length).unwrap()
}
};
let arg1 = match &args[1] {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(inferred_length).unwrap()
}
};

let arg0 = as_int32_array(&arg0)?;
let arg1 = as_int32_array(&arg1)?;

// 2. perform the computation
let array = arg0
Expand All @@ -1456,10 +1476,9 @@ mod tests {
})
.collect::<Int32Array>();

Ok(Arc::new(array) as ArrayRef)
};
Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
});

let fun = make_scalar_function(fun);
Arc::new(create_udf(
"udf_add",
input_types,
Expand Down
Loading