From 9bc03ed4eb257c04a69c3c1333cf25d856ec8568 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Tue, 9 Apr 2024 12:20:03 -0400 Subject: [PATCH 1/2] Move Nanvl and random functions to datafusion-functions --- datafusion-cli/Cargo.lock | 2 +- datafusion/core/tests/simplification.rs | 15 +- datafusion/expr/src/built_in_function.rs | 21 --- datafusion/expr/src/expr.rs | 11 +- datafusion/expr/src/expr_fn.rs | 7 - datafusion/expr/src/signature.rs | 2 +- datafusion/functions/Cargo.toml | 1 + datafusion/functions/src/math/mod.rs | 16 ++ datafusion/functions/src/math/nanvl.rs | 165 ++++++++++++++++++ datafusion/functions/src/math/random.rs | 108 ++++++++++++ datafusion/optimizer/src/push_down_filter.rs | 66 +++++-- datafusion/physical-expr/Cargo.toml | 1 - datafusion/physical-expr/src/functions.rs | 36 ++-- .../physical-expr/src/math_expressions.rs | 142 +-------------- datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 - datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 12 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 - 19 files changed, 383 insertions(+), 242 deletions(-) create mode 100644 datafusion/functions/src/math/nanvl.rs create mode 100644 datafusion/functions/src/math/random.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 447b69e414cf..a38fd7f1fe28 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1273,6 +1273,7 @@ dependencies = [ "itertools", "log", "md-5", + "rand", "regex", "sha2", "unicode-segmentation", @@ -1355,7 +1356,6 @@ dependencies = [ "md-5", "paste", "petgraph", - "rand", "regex", "sha2", ] diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index 5a2f040c09d8..a0bcdda84d64 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -28,9 +28,10 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ - expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable, - LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility, + expr, table_scan, Cast, ColumnarValue, Expr, ExprSchemable, LogicalPlan, + LogicalPlanBuilder, ScalarUDF, Volatility, }; +use datafusion_functions::math; use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use std::sync::Arc; @@ -383,17 +384,17 @@ fn test_const_evaluator_scalar_functions() { // volatile / stable functions should not be evaluated // rand() + (1 + 2) --> rand() + 3 - let fun = BuiltinScalarFunction::Random; - assert_eq!(fun.volatility(), Volatility::Volatile); - let rand = Expr::ScalarFunction(ScalarFunction::new(fun, vec![])); + let fun = math::random(); + assert_eq!(fun.signature().volatility, Volatility::Volatile); + let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![])); let expr = rand.clone() + (lit(1) + lit(2)); let expected = rand + lit(3); test_evaluate(expr, expected); // parenthesization matters: can't rewrite // (rand() + 1) + 2 --> (rand() + 1) + 2) - let fun = BuiltinScalarFunction::Random; - let rand = Expr::ScalarFunction(ScalarFunction::new(fun, vec![])); + let fun = math::random(); + let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![])); let expr = (rand + lit(1)) + lit(2); test_evaluate(expr.clone(), expr); } diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index a6795e99d751..43cb0c3e0a50 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -45,8 +45,6 @@ pub enum BuiltinScalarFunction { Exp, /// factorial Factorial, - /// nanvl - Nanvl, // string functions /// concat Concat, @@ -56,8 +54,6 @@ pub enum BuiltinScalarFunction { EndsWith, /// initcap InitCap, - /// random - Random, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -114,14 +110,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Exp => Volatility::Immutable, BuiltinScalarFunction::Factorial => Volatility::Immutable, - BuiltinScalarFunction::Nanvl => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, - - // Volatile builtin functions - BuiltinScalarFunction::Random => Volatility::Volatile, } } @@ -152,16 +144,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } - BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Factorial => Ok(Int64), - BuiltinScalarFunction::Nanvl => match &input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => { match input_expr_types[0] { Float32 => Ok(Float32), @@ -199,11 +185,6 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), - BuiltinScalarFunction::Nanvl => Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], - self.volatility(), - ), BuiltinScalarFunction::Factorial => { Signature::uniform(1, vec![Int64], self.volatility()) } @@ -240,8 +221,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Ceil => &["ceil"], BuiltinScalarFunction::Exp => &["exp"], BuiltinScalarFunction::Factorial => &["factorial"], - BuiltinScalarFunction::Nanvl => &["nanvl"], - BuiltinScalarFunction::Random => &["random"], // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ad15a81a2325..c7c50d871902 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1903,8 +1903,8 @@ mod test { use crate::expr::Cast; use crate::expr_fn::col; use crate::{ - case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition, - ScalarUDF, ScalarUDFImpl, Signature, Volatility, + case, lit, ColumnarValue, Expr, ScalarFunctionDefinition, ScalarUDF, + ScalarUDFImpl, Signature, Volatility, }; use arrow::datatypes::DataType; use datafusion_common::Column; @@ -2018,13 +2018,6 @@ mod test { #[test] fn test_is_volatile_scalar_func_definition() { - // BuiltIn - assert!( - ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random) - .is_volatile() - .unwrap() - ); - // UDF #[derive(Debug)] struct TestScalarUDF { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1e28e27af1e0..6a28275ebfcf 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -297,11 +297,6 @@ pub fn concat_ws(sep: Expr, values: Vec) -> Expr { )) } -/// Returns a random value in the range 0.0 <= x < 1.0 -pub fn random() -> Expr { - Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Random, vec![])) -} - /// Returns the approximate number of distinct input values. /// This function provides an approximation of count(DISTINCT x). /// Zero is returned if all input values are null. @@ -550,7 +545,6 @@ nary_scalar_expr!( "concatenates several strings, placing a seperator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); -scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { @@ -922,7 +916,6 @@ mod test { test_unary_scalar_expr!(Factorial, factorial); test_unary_scalar_expr!(Ceil, ceil); test_unary_scalar_expr!(Exp, exp); - test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 89f456f337f9..e2505d6fd65f 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -51,7 +51,7 @@ pub enum Volatility { Stable, /// A volatile function may change the return value from evaluation to evaluation. /// Multiple invocations of a volatile function may return different results when used in the - /// same query. An example of this is [super::BuiltinScalarFunction::Random]. DataFusion + /// same query. An example of this is the random() function. DataFusion /// can not evaluate such functions during planning. /// In the query `select col1, random() from t1`, `random()` function will be evaluated /// for each output row, resulting in a unique random value for each row. diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index ef7d2c9b1892..a6847f3327c0 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -77,6 +77,7 @@ hex = { version = "0.4", optional = true } itertools = { workspace = true } log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } +rand = { workspace = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 544de04e4a98..c83a98cb1913 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -27,8 +27,10 @@ pub mod iszero; pub mod lcm; pub mod log; pub mod nans; +pub mod nanvl; pub mod pi; pub mod power; +pub mod random; pub mod round; pub mod trunc; @@ -55,9 +57,11 @@ make_udf_function!(lcm::LcmFunc, LCM, lcm); make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)])); make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); +make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl); make_udf_function!(pi::PiFunc, PI, pi); make_udf_function!(power::PowerFunc, POWER, power); make_math_unary_udf!(RadiansFunc, RADIANS, radians, to_radians, None); +make_udf_function!(random::RandomFunc, RANDOM, random); make_udf_function!(round::RoundFunc, ROUND, round); make_math_unary_udf!(SignumFunc, SIGNUM, signum, signum, None); make_math_unary_udf!(SinFunc, SIN, sin, sin, None); @@ -180,6 +184,11 @@ pub mod expr_fn { super::log10().call(vec![num]) } + #[doc = "returns x if x is not NaN otherwise returns y"] + pub fn nanvl(x: Expr, y: Expr) -> Expr { + super::nanvl().call(vec![x, y]) + } + #[doc = "Returns an approximate value of π"] pub fn pi() -> Expr { super::pi().call(vec![]) @@ -195,6 +204,11 @@ pub mod expr_fn { super::radians().call(vec![num]) } + #[doc = "Returns a random value in the range 0.0 <= x < 1.0"] + pub fn random() -> Expr { + super::random().call(vec![]) + } + #[doc = "round to nearest integer"] pub fn round(args: Vec) -> Expr { super::round().call(args) @@ -261,9 +275,11 @@ pub fn functions() -> Vec> { log(), log2(), log10(), + nanvl(), pi(), power(), radians(), + random(), round(), signum(), sin(), diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs new file mode 100644 index 000000000000..d81a690843b6 --- /dev/null +++ b/datafusion/functions/src/math/nanvl.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Float32, Float64}; + +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct NanvlFunc { + signature: Signature, +} + +impl Default for NanvlFunc { + fn default() -> Self { + NanvlFunc::new() + } +} + +impl NanvlFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for NanvlFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "nanvl" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(nanvl, vec![])(args) + } +} + +/// Nanvl SQL function +fn nanvl(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + Float64 => { + let compute_nanvl = |x: f64, y: f64| { + if x.is_nan() { + y + } else { + x + } + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Float64Array, + { compute_nanvl } + )) as ArrayRef) + } + Float32 => { + let compute_nanvl = |x: f32, y: f32| { + if x.is_nan() { + y + } else { + x + } + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Float32Array, + { compute_nanvl } + )) as ArrayRef) + } + other => exec_err!("Unsupported data type {other:?} for function nanvl"), + } +} + +#[cfg(test)] +mod test { + use crate::math::nanvl::nanvl; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use std::sync::Arc; + + #[test] + fn test_nanvl_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // y + Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // x + ]; + + let result = nanvl(&args).expect("failed to initialize function nanvl"); + let floats = + as_float64_array(&result).expect("failed to initialize function nanvl"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 1.0); + assert_eq!(floats.value(1), 6.0); + assert_eq!(floats.value(2), 3.0); + assert!(floats.value(3).is_nan()); + } + + #[test] + fn test_nanvl_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // y + Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // x + ]; + + let result = nanvl(&args).expect("failed to initialize function nanvl"); + let floats = + as_float32_array(&result).expect("failed to initialize function nanvl"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 1.0); + assert_eq!(floats.value(1), 6.0); + assert_eq!(floats.value(2), 3.0); + assert!(floats.value(3).is_nan()); + } +} diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs new file mode 100644 index 000000000000..2c1ad4136702 --- /dev/null +++ b/datafusion/functions/src/math/random.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::iter; +use std::sync::Arc; + +use arrow::array::Float64Array; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Float64; +use rand::{thread_rng, Rng}; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub struct RandomFunc { + signature: Signature, +} + +impl Default for RandomFunc { + fn default() -> Self { + RandomFunc::new() + } +} + +impl RandomFunc { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } +} + +impl ScalarUDFImpl for RandomFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "random" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Float64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + random(args) + } +} + +/// Random SQL function +fn random(args: &[ColumnarValue]) -> Result { + let len: usize = match &args[0] { + ColumnarValue::Array(array) => array.len(), + _ => return exec_err!("Expect random function to take no param"), + }; + let mut rng = thread_rng(); + let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len); + let array = Float64Array::from_iter_values(values); + Ok(ColumnarValue::Array(Arc::new(array))) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::NullArray; + + use datafusion_common::cast::as_float64_array; + use datafusion_expr::ColumnarValue; + + use crate::math::random::random; + + #[test] + fn test_random_expression() { + let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))]; + let array = random(&args) + .expect("failed to initialize function random") + .into_array(1) + .expect("Failed to convert to array"); + let floats = + as_float64_array(&array).expect("failed to initialize function random"); + + assert_eq!(floats.len(), 1); + assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); + } +} diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ff24df259adf..f3ce8bbcde72 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1028,6 +1028,7 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { #[cfg(test)] mod tests { + use std::any::Any; use std::fmt::{Debug, Formatter}; use std::sync::Arc; @@ -1038,15 +1039,17 @@ mod tests { use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::{DFSchema, DFSchemaRef}; + use datafusion_common::{DFSchema, DFSchemaRef, ScalarValue}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, random, sum, - BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource, - TableType, UserDefinedLogicalNodeCore, + and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr, + ColumnarValue, Expr, Extension, LogicalPlanBuilder, Operator, ScalarUDF, + ScalarUDFImpl, Signature, TableSource, TableType, UserDefinedLogicalNodeCore, + Volatility, }; use async_trait::async_trait; + use datafusion_expr::expr::ScalarFunction; fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( @@ -2859,17 +2862,44 @@ Projection: a, b assert_optimized_plan_eq(&plan, expected) } + #[derive(Debug)] + struct TestScalarUDF { + signature: Signature, + } + + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "TestScalarUDF" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from(1))) + } + } + #[test] fn test_push_down_volatile_function_in_aggregate() -> Result<()> { - // SELECT t.a, t.r FROM (SELECT a, SUM(b), random()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5; + // SELECT t.a, t.r FROM (SELECT a, SUM(b), TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5; let table_scan = test_table_scan_with_name("test1")?; + let fun = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::exact(vec![], Volatility::Volatile), + }); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? - .project(vec![ - col("a"), - sum(col("b")), - add(random(), lit(1)).alias("r"), - ])? + .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])? .alias("t")? .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))? .project(vec![col("t.a"), col("t.r")])? @@ -2878,7 +2908,7 @@ Projection: a, b let expected_before = "Projection: t.a, t.r\ \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\ \n SubqueryAlias: t\ - \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ + \n Projection: test1.a, SUM(test1.b), TestScalarUDF() + Int32(1) AS r\ \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ \n TableScan: test1"; assert_eq!(format!("{plan:?}"), expected_before); @@ -2886,7 +2916,7 @@ Projection: a, b let expected_after = "Projection: t.a, t.r\ \n SubqueryAlias: t\ \n Filter: r > Float64(0.5)\ - \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ + \n Projection: test1.a, SUM(test1.b), TestScalarUDF() + Int32(1) AS r\ \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; assert_optimized_plan_eq(&plan, expected_after) @@ -2894,8 +2924,12 @@ Projection: a, b #[test] fn test_push_down_volatile_function_in_join() -> Result<()> { - // SELECT t.a, t.r FROM (SELECT test1.a AS a, random() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5; + // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5; let table_scan = test_table_scan_with_name("test1")?; + let fun = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::exact(vec![], Volatility::Volatile), + }); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); let left = LogicalPlanBuilder::from(table_scan).build()?; let right_table_scan = test_table_scan_with_name("test2")?; let right = LogicalPlanBuilder::from(right_table_scan).build()?; @@ -2909,7 +2943,7 @@ Projection: a, b ), None, )? - .project(vec![col("test1.a").alias("a"), random().alias("r")])? + .project(vec![col("test1.a").alias("a"), expr.alias("r")])? .alias("t")? .filter(col("t.r").gt(lit(0.8)))? .project(vec![col("t.a"), col("t.r")])? @@ -2918,7 +2952,7 @@ Projection: a, b let expected_before = "Projection: t.a, t.r\ \n Filter: t.r > Float64(0.8)\ \n SubqueryAlias: t\ - \n Projection: test1.a AS a, random() AS r\ + \n Projection: test1.a AS a, TestScalarUDF() AS r\ \n Inner Join: test1.a = test2.a\ \n TableScan: test1\ \n TableScan: test2"; @@ -2927,7 +2961,7 @@ Projection: a, b let expected = "Projection: t.a, t.r\ \n SubqueryAlias: t\ \n Filter: r > Float64(0.8)\ - \n Projection: test1.a AS a, random() AS r\ + \n Projection: test1.a AS a, TestScalarUDF() AS r\ \n Inner Join: test1.a = test2.a\ \n TableScan: test1\ \n TableScan: test2"; diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 72fac5370ae0..423087d2182b 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -70,7 +70,6 @@ log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } paste = "^1.0" petgraph = "0.6.2" -rand = { workspace = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 2be85a69d7da..db3e53246e13 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -184,10 +184,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Factorial => { Arc::new(|args| make_scalar_function_inner(math_expressions::factorial)(args)) } - BuiltinScalarFunction::Nanvl => { - Arc::new(|args| make_scalar_function_inner(math_expressions::nanvl)(args)) - } - BuiltinScalarFunction::Random => Arc::new(math_expressions::random), // string functions BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), @@ -290,12 +286,13 @@ mod tests { datatypes::Field, record_batch::RecordBatch, }; + use std::any::Any; use datafusion_common::cast::as_uint64_array; use datafusion_common::{internal_err, plan_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; - use datafusion_expr::Signature; + use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::expressions::lit; use crate::expressions::try_cast; @@ -542,17 +539,30 @@ mod tests { Ok(()) } - #[test] - fn test_empty_arguments() -> Result<()> { - let execution_props = ExecutionProps::new(); - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + #[derive(Debug)] + struct EmptyArgsUDF { + signature: Signature, + } - let funs = [BuiltinScalarFunction::Random]; + impl ScalarUDFImpl for EmptyArgsUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "EmptyArgsUDF" + } - for fun in funs.iter() { - create_physical_expr_with_type_coercion(fun, &[], &schema, &execution_props)?; + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) } - Ok(()) } // Helper function just for testing. diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 55fb54563787..004a9abe7f0b 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -18,14 +18,12 @@ //! Math expressions use std::any::type_name; -use std::iter; use std::sync::Arc; use arrow::array::ArrayRef; use arrow::array::{BooleanArray, Float32Array, Float64Array, Int64Array}; use arrow::datatypes::DataType; use arrow_array::Array; -use rand::{thread_rng, Rng}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; @@ -113,33 +111,6 @@ macro_rules! make_function_scalar_inputs { }}; } -macro_rules! make_function_inputs2 { - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE1>() - }}; -} - macro_rules! make_function_scalar_inputs_return_type { ($ARG: expr, $NAME:expr, $ARGS_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ let arg = downcast_arg!($ARG, $NAME, $ARGS_TYPE); @@ -169,51 +140,6 @@ pub fn factorial(args: &[ArrayRef]) -> Result { } } -/// Nanvl SQL function -pub fn nanvl(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => { - let compute_nanvl = |x: f64, y: f64| { - if x.is_nan() { - y - } else { - x - } - }; - - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float64Array, - { compute_nanvl } - )) as ArrayRef) - } - - DataType::Float32 => { - let compute_nanvl = |x: f32, y: f32| { - if x.is_nan() { - y - } else { - x - } - }; - - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float32Array, - { compute_nanvl } - )) as ArrayRef) - } - - other => exec_err!("Unsupported data type {other:?} for function nanvl"), - } -} - /// Isnan SQL function pub fn isnan(args: &[ArrayRef]) -> Result { match args[0].data_type() { @@ -237,42 +163,14 @@ pub fn isnan(args: &[ArrayRef]) -> Result { } } -/// Random SQL function -pub fn random(args: &[ColumnarValue]) -> Result { - let len: usize = match &args[0] { - ColumnarValue::Array(array) => array.len(), - _ => return exec_err!("Expect random function to take no param"), - }; - let mut rng = thread_rng(); - let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len); - let array = Float64Array::from_iter_values(values); - Ok(ColumnarValue::Array(Arc::new(array))) -} - #[cfg(test)] mod tests { - use arrow::array::{Float64Array, NullArray}; + use arrow::array::Float64Array; - use datafusion_common::cast::{ - as_boolean_array, as_float32_array, as_float64_array, as_int64_array, - }; + use datafusion_common::cast::{as_boolean_array, as_int64_array}; use super::*; - #[test] - fn test_random_expression() { - let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))]; - let array = random(&args) - .expect("failed to initialize function random") - .into_array(1) - .expect("Failed to convert to array"); - let floats = - as_float64_array(&array).expect("failed to initialize function random"); - - assert_eq!(floats.len(), 1); - assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); - } - #[test] fn test_factorial_i64() { let args: Vec = vec![ @@ -288,42 +186,6 @@ mod tests { assert_eq!(ints, &expected); } - #[test] - fn test_nanvl_f64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // y - Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // x - ]; - - let result = nanvl(&args).expect("failed to initialize function nanvl"); - let floats = - as_float64_array(&result).expect("failed to initialize function nanvl"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 1.0); - assert_eq!(floats.value(1), 6.0); - assert_eq!(floats.value(2), 3.0); - assert!(floats.value(3).is_nan()); - } - - #[test] - fn test_nanvl_f32() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // y - Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // x - ]; - - let result = nanvl(&args).expect("failed to initialize function nanvl"); - let floats = - as_float32_array(&result).expect("failed to initialize function nanvl"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 1.0); - assert_eq!(floats.value(1), 6.0); - assert_eq!(floats.value(2), 3.0); - assert!(floats.value(3).is_nan()); - } - #[test] fn test_isnan_f64() { let args: Vec = vec![Arc::new(Float64Array::from(vec![ diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c7c0d9b5a656..e1bcf33b8254 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -579,7 +579,7 @@ enum ScalarFunction { // 35 was MD5 // 36 was NullIf // 37 was OctetLength - Random = 38; + // 38 was Random // 39 was RegexpReplace // 40 was Repeat // 41 was Replace @@ -650,7 +650,7 @@ enum ScalarFunction { // 108 was ArrayReplaceN // 109 was ArrayRemoveAll // 110 was ArrayReplaceAll - Nanvl = 111; + // 111 was Nanvl // 112 was Flatten // 113 was IsNan // 114 was Iszero diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c8a1fba40765..7beaeef0e58b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22797,10 +22797,8 @@ impl serde::Serialize for ScalarFunction { Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", - Self::Random => "Random", Self::Coalesce => "Coalesce", Self::Factorial => "Factorial", - Self::Nanvl => "Nanvl", Self::EndsWith => "EndsWith", }; serializer.serialize_str(variant) @@ -22819,10 +22817,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Concat", "ConcatWithSeparator", "InitCap", - "Random", "Coalesce", "Factorial", - "Nanvl", "EndsWith", ]; @@ -22870,10 +22866,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), - "Random" => Ok(ScalarFunction::Random), "Coalesce" => Ok(ScalarFunction::Coalesce), "Factorial" => Ok(ScalarFunction::Factorial), - "Nanvl" => Ok(ScalarFunction::Nanvl), "EndsWith" => Ok(ScalarFunction::EndsWith), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index facf24219810..042c794e19de 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2878,7 +2878,7 @@ pub enum ScalarFunction { /// 35 was MD5 /// 36 was NullIf /// 37 was OctetLength - Random = 38, + /// 38 was Random /// 39 was RegexpReplace /// 40 was Repeat /// 41 was Replace @@ -2949,7 +2949,7 @@ pub enum ScalarFunction { /// 108 was ArrayReplaceN /// 109 was ArrayRemoveAll /// 110 was ArrayReplaceAll - Nanvl = 111, + /// 111 was Nanvl /// 112 was Flatten /// 113 was IsNan /// 114 was Iszero @@ -2992,10 +2992,8 @@ impl ScalarFunction { ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", - ScalarFunction::Random => "Random", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Factorial => "Factorial", - ScalarFunction::Nanvl => "Nanvl", ScalarFunction::EndsWith => "EndsWith", } } @@ -3008,10 +3006,8 @@ impl ScalarFunction { "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), - "Random" => Some(Self::Random), "Coalesce" => Some(Self::Coalesce), "Factorial" => Some(Self::Factorial), - "Nanvl" => Some(Self::Nanvl), "EndsWith" => Some(Self::EndsWith), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index e9eb53e45199..057690aacee6 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -41,9 +41,8 @@ use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, factorial, initcap, logical_plan::{PlanType, StringifiedPlan}, - nanvl, random, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, - BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, - GroupingSet, + AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, + Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -426,9 +425,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, - ScalarFunction::Random => Self::Random, ScalarFunction::Coalesce => Self::Coalesce, - ScalarFunction::Nanvl => Self::Nanvl, } } } @@ -1298,7 +1295,6 @@ pub fn parse_expr( ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Random => Ok(random()), ScalarFunction::Concat => { Ok(concat_expr(parse_exprs(args, registry, codec)?)) } @@ -1312,10 +1308,6 @@ pub fn parse_expr( ScalarFunction::Coalesce => { Ok(coalesce(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Nanvl => Ok(nanvl( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ed5e7a302b20..358eea785713 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1414,9 +1414,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, - BuiltinScalarFunction::Random => Self::Random, BuiltinScalarFunction::Coalesce => Self::Coalesce, - BuiltinScalarFunction::Nanvl => Self::Nanvl, }; Ok(scalar_function) From 463133ea037c46b61d195fa5a890d73ad9d681c0 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Tue, 9 Apr 2024 12:33:31 -0400 Subject: [PATCH 2/2] Remove extraneous test udf. --- datafusion/physical-expr/src/functions.rs | 29 +---------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index db3e53246e13..c237e2070675 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -286,13 +286,12 @@ mod tests { datatypes::Field, record_batch::RecordBatch, }; - use std::any::Any; use datafusion_common::cast::as_uint64_array; use datafusion_common::{internal_err, plan_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; - use datafusion_expr::{ScalarUDFImpl, Signature}; + use datafusion_expr::Signature; use crate::expressions::lit; use crate::expressions::try_cast; @@ -539,32 +538,6 @@ mod tests { Ok(()) } - #[derive(Debug)] - struct EmptyArgsUDF { - signature: Signature, - } - - impl ScalarUDFImpl for EmptyArgsUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "EmptyArgsUDF" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8) - } - - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) - } - } - // Helper function just for testing. // Returns `expressions` coerced to types compatible with // `signature`, if possible.