diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index 057cdd475273..1296c74ea277 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -59,8 +59,9 @@ cargo run --example csv_sql
 - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3
 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP
 - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass
+- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF)
+- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF)
 - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
-- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF)
 - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)
 
 ## Distributed
diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs
new file mode 100644
index 000000000000..6ebf88a0b671
--- /dev/null
+++ b/datafusion-examples/examples/advanced_udf.rs
@@ -0,0 +1,243 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::{
+    arrow::{
+        array::{ArrayRef, Float32Array, Float64Array},
+        datatypes::DataType,
+        record_batch::RecordBatch,
+    },
+    logical_expr::Volatility,
+};
+use std::any::Any;
+
+use arrow::array::{new_null_array, Array, AsArray};
+use arrow::compute;
+use arrow::datatypes::Float64Type;
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::{internal_err, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};
+use std::sync::Arc;
+
+/// This example shows how to use the full ScalarUDFImpl API to implement a user
+/// defined function. As in the `simple_udf.rs` example, this struct implements
+/// a function that takes two arguments and returns the first argument raised to
+/// the power of the second argument `a^b`.
+///
+/// To do so, we must implement the `ScalarUDFImpl` trait.
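+///
+/// The trait requires implementing `as_any`, `name`, `signature`,
+/// `return_type` and `invoke`; `aliases` has a default (no aliases) and is
+/// overridden below so the function can also be called as `my_pow`.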
+struct PowUdf {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl PowUdf {
+    /// Create a new instance of the `PowUdf` struct
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                // this function will always take two arguments of type f64
+                vec![DataType::Float64, DataType::Float64],
+                // this function is deterministic and will always return the same
+                // result for the same input
+                Volatility::Immutable,
+            ),
+            // we will also add an alias of "my_pow"
+            aliases: vec!["my_pow".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for PowUdf {
+    /// We implement as_any so that we can downcast the ScalarUDFImpl trait object
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Return the name of this function
+    fn name(&self) -> &str {
+        "pow"
+    }
+
+    /// Return the "signature" of this function -- namely what types of arguments it will take
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// What is the type of value that will be returned by this function? In
+    /// this case it will always be a constant value, but it could also be a
+    /// function of the input types.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    /// This is the function that actually calculates the results.
+    ///
+    /// This is the same way that functions built into DataFusion are invoked,
+    /// which permits important special cases when one or both of the arguments
+    /// are single values (constants). For example `pow(a, 2)`
+    ///
+    /// However, it also means the implementation is more complex than when
+    /// using `create_udf`.
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        // DataFusion has arranged for the correct inputs to be passed to this
+        // function, but we check again to make sure
+        assert_eq!(args.len(), 2);
+        let (base, exp) = (&args[0], &args[1]);
+        assert_eq!(base.data_type(), DataType::Float64);
+        assert_eq!(exp.data_type(), DataType::Float64);
+
+        match (base, exp) {
+            // For demonstration purposes we also implement the scalar / scalar
+            // case here, but it is not typically required for high performance.
+            //
+            // For performance it is most important to optimize cases where at
+            // least one argument is an array. If all arguments are constants,
+            // the DataFusion expression simplification logic will often invoke
+            // this path once during planning, and simply use the result during
+            // execution.
+            (
+                ColumnarValue::Scalar(ScalarValue::Float64(base)),
+                ColumnarValue::Scalar(ScalarValue::Float64(exp)),
+            ) => {
+                // compute the output. Note DataFusion treats `None` as NULL.
+                let res = match (base, exp) {
+                    (Some(base), Some(exp)) => Some(base.powf(*exp)),
+                    // one or both arguments were NULL
+                    _ => None,
+                };
+                Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
+            }
+            // special case if the exponent is a constant
+            (
+                ColumnarValue::Array(base_array),
+                ColumnarValue::Scalar(ScalarValue::Float64(exp)),
+            ) => {
+                let result_array = match exp {
+                    // a ^ null = null
+                    None => new_null_array(base_array.data_type(), base_array.len()),
+                    // a ^ exp
+                    Some(exp) => {
+                        // DataFusion has ensured both arguments are Float64:
+                        let base_array = base_array.as_primitive::<Float64Type>();
+                        // calculate the result for every row. The `unary`
+                        // kernel creates very fast "vectorized" code and
+                        // handles things like null values for us.
+                        let res: Float64Array =
+                            compute::unary(base_array, |base| base.powf(*exp));
+                        Arc::new(res)
+                    }
+                };
+                Ok(ColumnarValue::Array(result_array))
+            }
+
+            // special case if the base is a constant (note this code is quite
+            // similar to the previous case, so we omit comments)
+            (
+                ColumnarValue::Scalar(ScalarValue::Float64(base)),
+                ColumnarValue::Array(exp_array),
+            ) => {
+                let res = match base {
+                    None => new_null_array(exp_array.data_type(), exp_array.len()),
+                    Some(base) => {
+                        let exp_array = exp_array.as_primitive::<Float64Type>();
+                        let res: Float64Array =
+                            compute::unary(exp_array, |exp| base.powf(exp));
+                        Arc::new(res)
+                    }
+                };
+                Ok(ColumnarValue::Array(res))
+            }
+            // Both arguments are arrays so we have to perform the calculation for every row
+            (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
+                let res: Float64Array = compute::binary(
+                    base_array.as_primitive::<Float64Type>(),
+                    exp_array.as_primitive::<Float64Type>(),
+                    |base, exp| base.powf(exp),
+                )?;
+                Ok(ColumnarValue::Array(Arc::new(res)))
+            }
+            // if the types were not float, it is a bug in DataFusion
+            _ => {
+                use datafusion_common::DataFusionError;
+                internal_err!("Invalid argument types to pow function")
+            }
+        }
+    }
+
+    /// We will also add an alias of "my_pow"
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// In this example we register `PowUdf` as a user defined function
+/// and invoke it via the DataFrame API and SQL
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context()?;
+
+    // create the UDF
+    let pow = ScalarUDF::from(PowUdf::new());
+
+    // register the UDF with the context so it can be invoked by name and from SQL
+    ctx.register_udf(pow.clone());
+
+    // get a DataFrame from the context for scanning the "t" table
+    let df = ctx.table("t").await?;
+
+    // Call pow(a, 10) using the DataFrame API
+    let df = df.select(vec![pow.call(vec![col("a"), lit(10i32)])])?;
+
+    // note that the second argument is passed as an i32, not f64. DataFusion
+    // automatically coerces the types to match the UDF's defined signature.
+
+    // print the results
+    df.show().await?;
+
+    // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL
+    let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?;
+    sql_df.show().await?;
+
+    Ok(())
+}
+
+/// create local execution context with an in-memory table:
+///
+/// ```text
+/// +-----+-----+
+/// | a   | b   |
+/// +-----+-----+
+/// | 2.1 | 1.0 |
+/// | 3.1 | 2.0 |
+/// | 4.1 | 3.0 |
+/// | 5.1 | 4.0 |
+/// +-----+-----+
+/// ```
+fn create_context() -> Result<SessionContext> {
+    // define data.
+    let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1]));
+    let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
+    let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;
+
+    // declare a new context. In Spark API, this corresponds to a new SparkSession
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In Spark API, this corresponds to createDataFrame(...).
+    ctx.register_batch("t", batch)?;
+    Ok(ctx)
+}
diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs
index 591991786515..39e1e13ce39a 100644
--- a/datafusion-examples/examples/simple_udf.rs
+++ b/datafusion-examples/examples/simple_udf.rs
@@ -140,5 +140,11 @@ async fn main() -> Result<()> {
     // print the results
     df.show().await?;
 
+    // Given that `pow` is registered in the context, we can also use it in SQL:
+    let sql_df = ctx.sql("SELECT pow(a, b) FROM t").await?;
+
+    // print the results
+    sql_df.show().await?;
+
     Ok(())
 }
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index b46e9ec8f69d..0ec19bcadbf6 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1724,13 +1724,13 @@ mod test {
     use crate::expr::Cast;
     use crate::expr_fn::col;
     use crate::{
-        case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
-        ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
-        Volatility,
+        case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition,
+        ScalarUDF, ScalarUDFImpl, Signature, Volatility,
     };
     use arrow::datatypes::DataType;
     use datafusion_common::Column;
     use datafusion_common::{Result, ScalarValue};
+    use std::any::Any;
     use std::sync::Arc;
 
     #[test]
@@ -1848,24 +1848,41 @@
         );
 
         // UDF
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
-        let fun: ScalarFunctionImplementation =
-            Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
-        let udf = Arc::new(ScalarUDF::new(
-            "TestScalarUDF",
-            &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
-            &return_type,
-            &fun,
-        ));
+        struct TestScalarUDF {
+            signature: Signature,
+        }
+        impl ScalarUDFImpl for TestScalarUDF {
+            fn as_any(&self) -> &dyn Any {
+                self
+            }
+            fn name(&self) -> &str {
+                "TestScalarUDF"
+            }
+
+            fn signature(&self) -> &Signature {
+                &self.signature
+            }
+
+            fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+                Ok(DataType::Utf8)
+            }
+
+            fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+                Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
+            }
+        }
+        let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
+            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
+        }));
 
         assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
 
-        let udf = Arc::new(ScalarUDF::new(
-            "TestScalarUDF",
-            &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile),
-            &return_type,
-            &fun,
-        ));
+        let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
+            signature: Signature::uniform(
+                1,
+                vec![DataType::Float32],
+                Volatility::Volatile,
+            ),
+        }));
 
         assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
 
         // Unresolved function
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index cedf1d845137..eed41d97ccba 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -22,15 +22,16 @@ use crate::expr::{
     Placeholder, ScalarFunction, TryCast,
 };
 use crate::function::PartitionEvaluatorFactory;
-use crate::WindowUDF;
 use crate::{
     aggregate_function, built_in_function, conditional_expressions::CaseBuilder,
     logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
     BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
     ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
 };
+use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF};
 use arrow::datatypes::DataType;
 use datafusion_common::{Column, Result};
+use std::any::Any;
 use std::ops::Not;
 use std::sync::Arc;
 
@@ -944,11 +945,18 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder {
     CaseBuilder::new(None, vec![when], vec![then], None)
 }
 
-/// Creates a new UDF with a specific signature and specific return type.
-/// This is a helper function to create a new UDF.
-/// The function `create_udf` returns a subset of all possible `ScalarFunction`:
-/// * the UDF has a fixed return type
-/// * the UDF has a fixed signature (e.g. [f64, f64])
+/// Convenience method to create a new user defined scalar function (UDF) with a
+/// specific signature and specific return type.
+///
+/// Note this function does not expose all available features of [`ScalarUDF`],
+/// such as
+///
+/// * computing return types based on input types
+/// * multiple [`Signature`]s
+/// * aliases
+///
+/// See [`ScalarUDF`] for details and examples on how to use the full
+/// functionality.
 pub fn create_udf(
     name: &str,
     input_types: Vec<DataType>,
@@ -956,13 +964,66 @@ pub fn create_udf(
     volatility: Volatility,
     fun: ScalarFunctionImplementation,
 ) -> ScalarUDF {
-    let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
-    ScalarUDF::new(
+    let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone());
+    ScalarUDF::from(SimpleScalarUDF::new(
         name,
-        &Signature::exact(input_types, volatility),
-        &return_type,
-        &fun,
-    )
+        input_types,
+        return_type,
+        volatility,
+        fun,
+    ))
+}
+
+/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
+/// return type.
+pub struct SimpleScalarUDF {
+    name: String,
+    signature: Signature,
+    return_type: DataType,
+    fun: ScalarFunctionImplementation,
+}
+
+impl SimpleScalarUDF {
+    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
+    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility.
+    pub fn new(
+        name: impl Into<String>,
+        input_types: Vec<DataType>,
+        return_type: DataType,
+        volatility: Volatility,
+        fun: ScalarFunctionImplementation,
+    ) -> Self {
+        let name = name.into();
+        let signature = Signature::exact(input_types, volatility);
+        Self {
+            name,
+            signature,
+            return_type,
+            fun,
+        }
+    }
+}
+
+impl ScalarUDFImpl for SimpleScalarUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        (self.fun)(args)
+    }
 }
 
 /// Creates a new UDAF with a specific signature, state type and return type.
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 48532e13dcd7..bf8e9e2954f4 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -80,7 +80,7 @@ pub use signature::{
 };
 pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
 pub use udaf::AggregateUDF;
-pub use udf::ScalarUDF;
+pub use udf::{ScalarUDF, ScalarUDFImpl};
 pub use udwf::WindowUDF;
 pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
 pub use window_function::{BuiltInWindowFunction, WindowFunction};
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 3a18ca2d25e8..2ec80a4a9ea1 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -17,9 +17,12 @@
 //! [`ScalarUDF`]: Scalar User Defined Functions
 
-use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
+use crate::{
+    ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature,
+};
 use arrow::datatypes::DataType;
 use datafusion_common::Result;
+use std::any::Any;
 use std::fmt;
 use std::fmt::Debug;
 use std::fmt::Formatter;
@@ -27,11 +30,19 @@ use std::sync::Arc;
 
 /// Logical representation of a Scalar User Defined Function.
 ///
-/// A scalar function produces a single row output for each row of input.
+/// A scalar function produces a single row output for each row of input. This
+/// struct contains the information DataFusion needs to plan and invoke
+/// functions you supply, such as the name, type signature, return type, and
+/// actual implementation.
 ///
-/// This struct contains the information DataFusion needs to plan and invoke
-/// functions such name, type signature, return type, and actual implementation.
 ///
+/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`].
+///
+/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`].
+///
+/// [`create_udf`]: crate::expr_fn::create_udf
+/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
+/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
 #[derive(Clone)]
 pub struct ScalarUDF {
     /// The name of the function
@@ -79,7 +90,11 @@ impl std::hash::Hash for ScalarUDF {
 }
 
 impl ScalarUDF {
-    /// Create a new ScalarUDF
+    /// Create a new ScalarUDF from low level details.
+    ///
+    /// See [`ScalarUDFImpl`] for a more convenient way to create a
+    /// `ScalarUDF` using trait objects.
+    #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")]
     pub fn new(
         name: &str,
         signature: &Signature,
@@ -95,6 +110,34 @@ impl ScalarUDF {
         }
     }
 
+    /// Create a new `ScalarUDF` from a [`ScalarUDFImpl`] trait object
+    ///
+    /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
+    pub fn new_from_impl<F>(fun: F) -> ScalarUDF
+    where
+        F: ScalarUDFImpl + Send + Sync + 'static,
+    {
+        // TODO change the internal implementation to use the trait object
+        let arc_fun = Arc::new(fun);
+        let captured_self = arc_fun.clone();
+        let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
+            let return_type = captured_self.return_type(arg_types)?;
+            Ok(Arc::new(return_type))
+        });
+
+        let captured_self = arc_fun.clone();
+        let func: ScalarFunctionImplementation =
+            Arc::new(move |args| captured_self.invoke(args));
+
+        Self {
+            name: arc_fun.name().to_string(),
+            signature: arc_fun.signature().clone(),
+            return_type: return_type.clone(),
+            fun: func,
+            aliases: arc_fun.aliases().to_vec(),
+        }
+    }
+
     /// Adds additional names that can be used to invoke this function, in addition to `name`
     pub fn with_aliases(
         mut self,
@@ -105,7 +148,9 @@ impl ScalarUDF {
         self
     }
 
-    /// creates a logical expression with a call of the UDF
+    /// Returns an [`Expr`] logical expression to call this UDF with specified
+    /// arguments.
+    ///
     /// This utility allows using the UDF without requiring access to the registry.
     pub fn call(&self, args: Vec<Expr>) -> Expr {
         Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
@@ -124,22 +169,126 @@ impl ScalarUDF {
         &self.aliases
     }
 
-    /// Returns this function's signature (what input types are accepted)
+    /// Returns this function's [`Signature`] (what input types are accepted)
     pub fn signature(&self) -> &Signature {
         &self.signature
     }
 
-    /// Return the type of the function given its input types
+    /// The datatype this function returns given the input argument types
     pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
         // Old API returns an Arc of the datatype for some reason
         let res = (self.return_type)(args)?;
         Ok(res.as_ref().clone())
     }
 
-    /// Return the actual implementation
+    /// Return an [`Arc`] to the function implementation
     pub fn fun(&self) -> ScalarFunctionImplementation {
         self.fun.clone()
     }
+}
 
-    // TODO maybe add an invoke() method that runs the actual function?
-}
+impl<F> From<F> for ScalarUDF
+where
+    F: ScalarUDFImpl + Send + Sync + 'static,
+{
+    fn from(fun: F) -> Self {
+        Self::new_from_impl(fun)
+    }
+}
+
+/// Trait for implementing [`ScalarUDF`].
+///
+/// This trait exposes the full API for implementing user defined functions and
+/// can be used to implement any function.
+///
+/// See [`advanced_udf.rs`] for a full example with complete implementation and
+/// [`ScalarUDF`] for other available options.
+///
+///
+/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+/// # Basic Example
+/// ```
+/// # use std::any::Any;
+/// # use arrow::datatypes::DataType;
+/// # use datafusion_common::{DataFusionError, plan_err, Result};
+/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
+/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
+/// struct AddOne {
+///   signature: Signature
+/// };
+///
+/// impl AddOne {
+///   fn new() -> Self {
+///     Self {
+///       signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable)
+///     }
+///   }
+/// }
+///
+/// /// Implement the ScalarUDFImpl trait for AddOne
+/// impl ScalarUDFImpl for AddOne {
+///    fn as_any(&self) -> &dyn Any { self }
+///    fn name(&self) -> &str { "add_one" }
+///    fn signature(&self) -> &Signature { &self.signature }
+///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
+///      if !matches!(args.get(0), Some(&DataType::Int32)) {
+///        return plan_err!("add_one only accepts Int32 arguments");
+///      }
+///      Ok(DataType::Int32)
+///    }
+///    // The actual implementation would add one to the argument
+///    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { unimplemented!() }
+/// }
+///
+/// // Create a new ScalarUDF from the implementation
+/// let add_one = ScalarUDF::from(AddOne::new());
+///
+/// // Call the function `add_one(col)`
+/// let expr = add_one.call(vec![col("a")]);
+/// ```
+pub trait ScalarUDFImpl {
+    /// Returns this object as an [`Any`] trait object
+    fn as_any(&self) -> &dyn Any;
+
+    /// Returns this function's name
+    fn name(&self) -> &str;
+
+    /// Returns the function's [`Signature`] for information about what input
+    /// types are accepted and the function's Volatility.
+    fn signature(&self) -> &Signature;
+
+    /// What [`DataType`] will be returned by this function, given the types of
+    /// the arguments
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
+
+    /// Invoke the function on `args`, returning the appropriate result
+    ///
+    /// The function will be invoked with the slice of [`ColumnarValue`]
+    /// (either scalar or array).
+    ///
+    /// # Zero Argument Functions
+    /// If the function has zero parameters (e.g. `now()`) it will be passed a
+    /// single element slice which is a null array to indicate the batch's row
+    /// count (so the function can know the resulting array size).
+    ///
+    /// # Performance
+    ///
+    /// For the best performance, the implementations of `invoke` should handle
+    /// the common case when one or more of their arguments are constant values
+    /// (aka [`ColumnarValue::Scalar`]). Calling [`ColumnarValue::into_array`]
+    /// and treating all arguments as arrays will work, but will be slower.
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue>;
+
+    /// Returns any aliases (alternate names) for this function.
+    ///
+    /// Aliases can be used to invoke the same function using different names.
+    /// For example in some databases `now()` and `current_timestamp()` are
+    /// aliases for the same function. This behavior can be obtained by
+    /// returning `current_timestamp` as an alias for the `now` function.
+    ///
+    /// Note: `aliases` should only include names other than [`Self::name`].
+    /// Defaults to `[]` (no aliases)
+    fn aliases(&self) -> &[String] {
+        &[]
+    }
+}
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index c5e1180b9f97..b6298f5b552f 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -738,7 +738,8 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result {
 
 #[cfg(test)]
 mod test {
-    use std::sync::Arc;
+    use std::any::Any;
+    use std::sync::{Arc, OnceLock};
 
     use arrow::array::{FixedSizeListArray, Int32Array};
     use arrow::datatypes::{DataType, TimeUnit};
@@ -750,13 +751,13 @@
     use datafusion_expr::{
         cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction,
         AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case,
-        ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery,
+        ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, StateTypeFunction,
+        Subquery,
     };
     use datafusion_expr::{
         lit,
         logical_plan::{EmptyRelation, Projection},
-        Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF,
-        Signature, Volatility,
+        Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature, Volatility,
     };
     use datafusion_physical_expr::expressions::AvgAccumulator;
 
@@ -808,22 +809,36 @@
         assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
     }
 
+    static TEST_SIGNATURE: OnceLock<Signature> = OnceLock::new();
+
+    struct TestScalarUDF {}
+    impl ScalarUDFImpl for TestScalarUDF {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "TestScalarUDF"
+        }
+        fn signature(&self) -> &Signature {
+            TEST_SIGNATURE.get_or_init(|| {
+                Signature::uniform(1, vec![DataType::Float32], Volatility::Stable)
+            })
+        }
+        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+            Ok(DataType::Utf8)
+        }
+
+        fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+            Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
+        }
+    }
+
     #[test]
     fn scalar_udf() -> Result<()> {
         let empty = empty();
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
-        let fun: ScalarFunctionImplementation =
-            Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
-        let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
-            Arc::new(ScalarUDF::new(
-                "TestScalarUDF",
-                &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
-                &return_type,
-                &fun,
-            )),
-            vec![lit(123_i32)],
-        ));
+
+        let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit(123_i32)]);
         let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
         let expected =
             "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n  EmptyRelation";
@@ -833,24 +848,13 @@
     #[test]
     fn scalar_udf_invalid_input() -> Result<()> {
         let empty = empty();
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
-        let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!());
-        let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
-            Arc::new(ScalarUDF::new(
-                "TestScalarUDF",
-                &Signature::uniform(1, vec![DataType::Int32], Volatility::Stable),
-                &return_type,
-                &fun,
-            )),
-            vec![lit("Apple")],
-        ));
+        let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]);
         let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
         let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "")
             .err()
             .unwrap();
         assert_eq!(
-            "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.",
+            "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float32]) failed.",
             err.strip_backtrace()
         );
         Ok(())
diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md
index 11cf52eb3fcf..c51e4de3236c 100644
--- a/docs/source/library-user-guide/adding-udfs.md
+++ b/docs/source/library-user-guide/adding-udfs.md
@@ -76,7 +76,9 @@ The challenge however is that DataFusion doesn't know about this function. We ne
 
 ### Registering a Scalar UDF
 
-To register a Scalar UDF, you need to wrap the function implementation in a `ScalarUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udf` and `make_scalar_function` helper functions to make this easier.
+To register a Scalar UDF, you need to wrap the function implementation in a [`ScalarUDF`] struct and then register it with the `SessionContext`.
+DataFusion provides the [`create_udf`] and [`make_scalar_function`] helper functions to make this easier.
+There is also a lower level API, with more functionality but more complexity, that is documented in [`advanced_udf.rs`].
 
 ```rust
 use datafusion::logical_expr::{Volatility, create_udf};
@@ -93,6 +95,11 @@ let udf = create_udf(
 );
 ```
 
+[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
+[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
+[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html
+[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+
 A few things to note:
 
 - The first argument is the name of the function. This is the name that will be used in SQL queries.
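For readers of this change, the sketch below shows what the documentation's `add_one` function could look like when written against the lower level `ScalarUDFImpl` API introduced in `datafusion/expr/src/udf.rs` above, and how it might be registered and invoked. It is a minimal sketch, not part of this change: the struct name `AddOneUdf`, the helper `register_and_query`, and the table `t` with an `Int32` column `x` are illustrative assumptions; the trait methods, `ScalarUDF::from`, `register_udf`, and `call` are the ones added or used elsewhere in this diff.

```rust
use std::any::Any;
use std::sync::Arc;

use arrow::array::{AsArray, Int32Array};
use arrow::compute;
use arrow::datatypes::{DataType, Int32Type};
use datafusion::error::Result;
use datafusion::logical_expr::Volatility;
use datafusion::prelude::*;
use datafusion_common::{internal_err, DataFusionError};
use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};

/// `add_one(x)` expressed with the ScalarUDFImpl trait instead of `create_udf`
/// (illustrative sketch).
struct AddOneUdf {
    signature: Signature,
}

impl AddOneUdf {
    fn new() -> Self {
        Self {
            // exactly one Int32 argument; deterministic, so Immutable
            signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for AddOneUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "add_one"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int32)
    }
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        // For brevity only the array case is handled here; advanced_udf.rs
        // shows how to also special-case ColumnarValue::Scalar arguments.
        match &args[0] {
            ColumnarValue::Array(array) => {
                let values = array.as_primitive::<Int32Type>();
                // vectorized `+ 1` over every row, preserving nulls
                let res: Int32Array = compute::unary(values, |v| v + 1);
                Ok(ColumnarValue::Array(Arc::new(res)))
            }
            _ => internal_err!("add_one expects an array argument"),
        }
    }
}

/// Register the UDF and call it from the DataFrame API and SQL
/// (assumes a table "t" with an Int32 column "x" is already registered).
async fn register_and_query(ctx: &SessionContext) -> Result<()> {
    let add_one = ScalarUDF::from(AddOneUdf::new());
    ctx.register_udf(add_one.clone());

    let df = ctx.table("t").await?;
    let df = df.select(vec![add_one.call(vec![col("x")])])?;
    df.show().await?;

    ctx.sql("SELECT add_one(x) FROM t").await?.show().await?;
    Ok(())
}
```

Compared to the `create_udf` form shown in the documentation hunk above, this shape leaves room to later add aliases, multiple signatures, or input-dependent return types without changing how the function is registered.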