Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Nanvl and random functions to datafusion-functions #10017

Merged
merged 2 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable,
LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility,
expr, table_scan, Cast, ColumnarValue, Expr, ExprSchemable, LogicalPlan,
LogicalPlanBuilder, ScalarUDF, Volatility,
};
use datafusion_functions::math;
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use std::sync::Arc;
Expand Down Expand Up @@ -383,17 +384,17 @@ fn test_const_evaluator_scalar_functions() {

// volatile / stable functions should not be evaluated
// rand() + (1 + 2) --> rand() + 3
let fun = BuiltinScalarFunction::Random;
assert_eq!(fun.volatility(), Volatility::Volatile);
let rand = Expr::ScalarFunction(ScalarFunction::new(fun, vec![]));
let fun = math::random();
assert_eq!(fun.signature().volatility, Volatility::Volatile);
let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
let expr = rand.clone() + (lit(1) + lit(2));
let expected = rand + lit(3);
test_evaluate(expr, expected);

// parenthesization matters: can't rewrite
// (rand() + 1) + 2 --> (rand() + 1) + 2
let fun = BuiltinScalarFunction::Random;
let rand = Expr::ScalarFunction(ScalarFunction::new(fun, vec![]));
let fun = math::random();
let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
let expr = (rand + lit(1)) + lit(2);
test_evaluate(expr.clone(), expr);
}
Expand Down
21 changes: 0 additions & 21 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ pub enum BuiltinScalarFunction {
Exp,
/// factorial
Factorial,
/// nanvl
Nanvl,
// string functions
/// concat
Concat,
Expand All @@ -56,8 +54,6 @@ pub enum BuiltinScalarFunction {
EndsWith,
/// initcap
InitCap,
/// random
Random,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -114,14 +110,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Exp => Volatility::Immutable,
BuiltinScalarFunction::Factorial => Volatility::Immutable,
BuiltinScalarFunction::Nanvl => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
BuiltinScalarFunction::InitCap => Volatility::Immutable,

// Volatile builtin functions
BuiltinScalarFunction::Random => Volatility::Volatile,
}
}

Expand Down Expand Up @@ -152,16 +144,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::InitCap => {
utf8_to_str_type(&input_expr_types[0], "initcap")
}
BuiltinScalarFunction::Random => Ok(Float64),
BuiltinScalarFunction::EndsWith => Ok(Boolean),

BuiltinScalarFunction::Factorial => Ok(Int64),

BuiltinScalarFunction::Nanvl => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},

BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => {
match input_expr_types[0] {
Float32 => Ok(Float32),
Expand Down Expand Up @@ -199,11 +185,6 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()),
BuiltinScalarFunction::Nanvl => Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
self.volatility(),
),
BuiltinScalarFunction::Factorial => {
Signature::uniform(1, vec![Int64], self.volatility())
}
Expand Down Expand Up @@ -240,8 +221,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Exp => &["exp"],
BuiltinScalarFunction::Factorial => &["factorial"],
BuiltinScalarFunction::Nanvl => &["nanvl"],
BuiltinScalarFunction::Random => &["random"],

// conditional functions
BuiltinScalarFunction::Coalesce => &["coalesce"],
Expand Down
11 changes: 2 additions & 9 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1903,8 +1903,8 @@ mod test {
use crate::expr::Cast;
use crate::expr_fn::col;
use crate::{
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
case, lit, ColumnarValue, Expr, ScalarFunctionDefinition, ScalarUDF,
ScalarUDFImpl, Signature, Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Column;
Expand Down Expand Up @@ -2018,13 +2018,6 @@ mod test {

#[test]
fn test_is_volatile_scalar_func_definition() {
// BuiltIn
assert!(
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random)
.is_volatile()
.unwrap()
);

// UDF
#[derive(Debug)]
struct TestScalarUDF {
Expand Down
7 changes: 0 additions & 7 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,6 @@ pub fn concat_ws(sep: Expr, values: Vec<Expr>) -> Expr {
))
}

/// Builds an expression that invokes the built-in `Random` scalar function
/// with no arguments; the resulting value lies in the range 0.0 <= x < 1.0.
pub fn random() -> Expr {
    let no_args = vec![];
    let call = ScalarFunction::new(BuiltinScalarFunction::Random, no_args);
    Expr::ScalarFunction(call)
}

/// Returns the approximate number of distinct input values.
/// This function provides an approximation of count(DISTINCT x).
/// Zero is returned if all input values are null.
Expand Down Expand Up @@ -550,7 +545,6 @@ nary_scalar_expr!(
"concatenates several strings, placing a seperator between each one"
);
nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y");

/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
pub fn case(expr: Expr) -> CaseBuilder {
Expand Down Expand Up @@ -922,7 +916,6 @@ mod test {
test_unary_scalar_expr!(Factorial, factorial);
test_unary_scalar_expr!(Ceil, ceil);
test_unary_scalar_expr!(Exp, exp);
test_scalar_expr!(Nanvl, nanvl, x, y);

test_scalar_expr!(InitCap, initcap, string);
test_scalar_expr!(EndsWith, ends_with, string, characters);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub enum Volatility {
Stable,
/// A volatile function may change the return value from evaluation to evaluation.
/// Multiple invocations of a volatile function may return different results when used in the
/// same query. An example of this is [super::BuiltinScalarFunction::Random]. DataFusion
/// same query. An example of this is the random() function. DataFusion
/// can not evaluate such functions during planning.
/// In the query `select col1, random() from t1`, `random()` function will be evaluated
/// for each output row, resulting in a unique random value for each row.
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ hex = { version = "0.4", optional = true }
itertools = { workspace = true }
log = { workspace = true }
md-5 = { version = "^0.10.0", optional = true }
rand = { workspace = true }
regex = { version = "1.8", optional = true }
sha2 = { version = "^0.10.1", optional = true }
unicode-segmentation = { version = "^1.7.1", optional = true }
Expand Down
16 changes: 16 additions & 0 deletions datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ pub mod iszero;
pub mod lcm;
pub mod log;
pub mod nans;
pub mod nanvl;
pub mod pi;
pub mod power;
pub mod random;
pub mod round;
pub mod trunc;

Expand All @@ -55,9 +57,11 @@ make_udf_function!(lcm::LcmFunc, LCM, lcm);
make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)]));
make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)]));
make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)]));
make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl);
make_udf_function!(pi::PiFunc, PI, pi);
make_udf_function!(power::PowerFunc, POWER, power);
make_math_unary_udf!(RadiansFunc, RADIANS, radians, to_radians, None);
make_udf_function!(random::RandomFunc, RANDOM, random);
make_udf_function!(round::RoundFunc, ROUND, round);
make_math_unary_udf!(SignumFunc, SIGNUM, signum, signum, None);
make_math_unary_udf!(SinFunc, SIN, sin, sin, None);
Expand Down Expand Up @@ -180,6 +184,11 @@ pub mod expr_fn {
super::log10().call(vec![num])
}

#[doc = "Returns `x` unless it is NaN, in which case `y` is returned"]
pub fn nanvl(x: Expr, y: Expr) -> Expr {
    // Look up the `nanvl` UDF from the parent module, then apply it to both operands.
    let udf = super::nanvl();
    let operands = vec![x, y];
    udf.call(operands)
}

#[doc = "Returns an approximate value of π"]
pub fn pi() -> Expr {
super::pi().call(vec![])
Expand All @@ -195,6 +204,11 @@ pub mod expr_fn {
super::radians().call(vec![num])
}

#[doc = "Produces a random value `x` with 0.0 <= x < 1.0"]
pub fn random() -> Expr {
    // The `random` UDF from the parent module is called with an empty argument list.
    let udf = super::random();
    let no_args = vec![];
    udf.call(no_args)
}

#[doc = "round to nearest integer"]
pub fn round(args: Vec<Expr>) -> Expr {
super::round().call(args)
Expand Down Expand Up @@ -261,9 +275,11 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
log(),
log2(),
log10(),
nanvl(),
pi(),
power(),
radians(),
random(),
round(),
signum(),
sin(),
Expand Down
Loading
Loading