Skip to content

Commit

Permalink
Add helper function for processing scalar function input (apache#8962)
Browse files Browse the repository at this point in the history
* Add helper function for scalar function

* Update datafusion/physical-expr/src/functions.rs

Co-authored-by: Andrew Lamb <[email protected]>

* Fix

* Fix

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
viirya and alamb authored Jan 24, 2024
1 parent d81c82d commit d6ab343
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 43 deletions.
18 changes: 4 additions & 14 deletions datafusion-examples/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::cast::as_float64_array;
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::functions::columnar_values_to_array;
use std::sync::Arc;

/// create local execution context with an in-memory table:
Expand Down Expand Up @@ -70,22 +71,11 @@ async fn main() -> Result<()> {
// this is guaranteed by DataFusion based on the function's signature.
assert_eq!(args.len(), 2);

// Try to obtain row number
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let inferred_length = len.unwrap_or(1);

let arg0 = args[0].clone().into_array(inferred_length)?;
let arg1 = args[1].clone().into_array(inferred_length)?;
let args = columnar_values_to_array(args)?;

// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
let base = as_float64_array(&arg0).expect("cast failed");
let exponent = as_float64_array(&arg1).expect("cast failed");
let base = as_float64_array(&args[0]).expect("cast failed");
let exponent = as_float64_array(&args[1]).expect("cast failed");

// this is guaranteed by DataFusion. We place it just to make it obvious.
assert_eq!(exponent.len(), base.len());
Expand Down
27 changes: 4 additions & 23 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,7 @@ mod tests {
use datafusion_physical_expr::execution_props::ExecutionProps;

use chrono::{DateTime, TimeZone, Utc};
use datafusion_physical_expr::functions::columnar_values_to_array;

// ------------------------------
// --- ExprSimplifier tests -----
Expand Down Expand Up @@ -1489,30 +1490,10 @@ mod tests {
let return_type = Arc::new(DataType::Int32);

let fun = Arc::new(|args: &[ColumnarValue]| {
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let inferred_length = len.unwrap_or(1);

let arg0 = match &args[0] {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(inferred_length).unwrap()
}
};
let arg1 = match &args[1] {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(inferred_length).unwrap()
}
};
let args = columnar_values_to_array(args)?;

let arg0 = as_int32_array(&arg0)?;
let arg1 = as_int32_array(&arg1)?;
let arg0 = as_int32_array(&args[0])?;
let arg1 = as_int32_array(&args[1])?;

// 2. perform the computation
let array = arg0
Expand Down
46 changes: 46 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use arrow::{
compute::kernels::length::{bit_length, length},
datatypes::{DataType, Int32Type, Int64Type, Schema},
};
use arrow_array::Array;
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
pub use datafusion_expr::FuncMonotonicity;
use datafusion_expr::{
Expand Down Expand Up @@ -191,6 +192,51 @@ pub(crate) enum Hint {
AcceptsSingular,
}

/// A helper function used to infer the length of arguments of Scalar functions and convert
/// [`ColumnarValue`]s to [`ArrayRef`]s with the inferred length. Note that this function
/// only works for functions that accept either that all arguments are scalars or all arguments
/// are arrays with same length. Otherwise, it will return an error.
pub fn columnar_values_to_array(args: &[ColumnarValue]) -> Result<Vec<ArrayRef>> {
if args.is_empty() {
return Ok(vec![]);
}

let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) if acc.is_none() => Some(1),
ColumnarValue::Scalar(_) => {
if let Some(1) = acc {
acc
} else {
None
}
}
ColumnarValue::Array(a) => {
if let Some(l) = acc {
if l == a.len() {
acc
} else {
None
}
} else {
Some(a.len())
}
}
});

let inferred_length = len.ok_or(DataFusionError::Internal(
"Arguments has mixed length".to_string(),
))?;

let args = args
.iter()
.map(|arg| arg.clone().into_array(inferred_length))
.collect::<Result<Vec<_>>>()?;

Ok(args)
}

/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
/// and vice-versa after evaluation.
/// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar.
Expand Down
11 changes: 5 additions & 6 deletions docs/source/library-user-guide/adding-udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::common::Result;

use datafusion::common::cast::as_int64_array;
use datafusion::physical_plan::functions::columnar_values_to_array;

pub fn add_one(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn add_one(args: &[ColumnarValue]) -> Result<ArrayRef> {
// Error handling omitted for brevity

let args = columnar_values_to_array(args)?;
let i64s = as_int64_array(&args[0])?;

let new_array = i64s
Expand Down Expand Up @@ -82,7 +82,6 @@ There is a lower level API with more functionality but is more complex, that is

```rust
use datafusion::logical_expr::{Volatility, create_udf};
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::arrow::datatypes::DataType;
use std::sync::Arc;

Expand All @@ -91,13 +90,13 @@ let udf = create_udf(
vec![DataType::Int64],
Arc::new(DataType::Int64),
Volatility::Immutable,
make_scalar_function(add_one),
Arc::new(add_one),
);
```

[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html
[`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html
[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs

A few things to note:
Expand Down

0 comments on commit d6ab343

Please sign in to comment.