Skip to content

Commit

Permalink
Refactor: Unify Expr::ScalarFunction and Expr::ScalarUDF, introduce unresolved functions by name (#8258)
Browse files Browse the repository at this point in the history

* Refactor Expr::ScalarFunction

* Remove Expr::ScalarUDF

* review comments

* make name() return &str

* fix fmt

* fix after merge
  • Loading branch information
2010YOUY01 authored Nov 26, 2023
1 parent b648d4e commit f8dcc64
Show file tree
Hide file tree
Showing 21 changed files with 419 additions and 271 deletions.
54 changes: 30 additions & 24 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ use super::PartitionedFile;
use crate::datasource::listing::ListingTableUrl;
use crate::execution::context::SessionState;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::{Column, DFField, DFSchema, DataFusionError};
use datafusion_expr::expr::ScalarUDF;
use datafusion_expr::{Expr, Volatility};
use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError};
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use object_store::path::Path;
Expand All @@ -54,13 +53,13 @@ use object_store::{ObjectMeta, ObjectStore};
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
let mut is_applicable = true;
expr.apply(&mut |expr| {
Ok(match expr {
match expr {
Expr::Column(Column { ref name, .. }) => {
is_applicable &= col_names.contains(name);
if is_applicable {
VisitRecursion::Skip
Ok(VisitRecursion::Skip)
} else {
VisitRecursion::Stop
Ok(VisitRecursion::Stop)
}
}
Expr::Literal(_)
Expand Down Expand Up @@ -89,25 +88,32 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
| Expr::GroupingSet(_)
| Expr::Case { .. } => VisitRecursion::Continue,
| Expr::Case { .. } => Ok(VisitRecursion::Continue),

Expr::ScalarFunction(scalar_function) => {
match scalar_function.fun.volatility() {
Volatility::Immutable => VisitRecursion::Continue,
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
VisitRecursion::Stop
match &scalar_function.func_def {
ScalarFunctionDefinition::BuiltIn { fun, .. } => {
match fun.volatility() {
Volatility::Immutable => Ok(VisitRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
}
}
}
}
}
Expr::ScalarUDF(ScalarUDF { fun, .. }) => {
match fun.signature().volatility {
Volatility::Immutable => VisitRecursion::Continue,
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
VisitRecursion::Stop
ScalarFunctionDefinition::UDF(fun) => {
match fun.signature().volatility {
Volatility::Immutable => Ok(VisitRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
}
}
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
}
}
Expand All @@ -123,9 +129,9 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::Wildcard { .. }
| Expr::Placeholder(_) => {
is_applicable = false;
VisitRecursion::Stop
Ok(VisitRecursion::Stop)
}
})
}
})
.unwrap();
is_applicable
Expand Down
18 changes: 10 additions & 8 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ use datafusion_common::{
use datafusion_expr::dml::{CopyOptions, CopyTo};
use datafusion_expr::expr::{
self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast,
GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, ScalarUDF, TryCast,
WindowFunction,
GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction,
};
use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols};
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{
DescribeTable, DmlStatement, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame,
WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::expressions::Literal;
use datafusion_sql::utils::window_expr_common_partition_keys;
Expand Down Expand Up @@ -217,11 +217,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {

Ok(name)
}
Expr::ScalarFunction(func) => {
create_function_physical_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_physical_name(fun.name(), false, args)
Expr::ScalarFunction(expr::ScalarFunction { func_def, args }) => {
// function should be resolved during `AnalyzerRule`s
if let ScalarFunctionDefinition::Name(_) = func_def {
return internal_err!("Function `Expr` with name should be resolved.");
}

create_function_physical_name(func_def.name(), false, args)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
create_function_physical_name(&fun.to_string(), false, args)
Expand Down
79 changes: 47 additions & 32 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,8 @@ pub enum Expr {
TryCast(TryCast),
/// A sort expression, that can be used to sort values.
Sort(Sort),
/// Represents the call of a built-in scalar function with a set of arguments.
/// Represents the call of a scalar function with a set of arguments.
ScalarFunction(ScalarFunction),
/// Represents the call of a user-defined scalar function with arguments.
ScalarUDF(ScalarUDF),
/// Represents the call of an aggregate built-in function with arguments.
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
Expand Down Expand Up @@ -338,37 +336,61 @@ impl Between {
}
}

/// Identifies which implementation of a scalar function DataFusion should call.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ScalarFunctionDefinition {
    /// The function has been resolved to a `BuiltinScalarFunction`.
    ///
    /// There is a plan to migrate `BuiltinScalarFunction` to a UDF-based
    /// implementation (issue #8045), so this variant is expected to be
    /// removed in the long term.
    BuiltIn {
        /// The resolved built-in function.
        fun: built_in_function::BuiltinScalarFunction,
        /// The function's name.
        name: Arc<str>,
    },
    /// The function has been resolved to a user-defined function.
    UDF(Arc<crate::ScalarUDF>),
    /// A scalar function that is known only by its name.
    ///
    /// This variant cannot be executed directly; it must be resolved to one
    /// of the other variants prior to physical planning.
    Name(Arc<str>),
}

/// ScalarFunction expression invokes a scalar function (built-in, user-defined, or not yet resolved by name)
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarFunction {
/// The function
pub fun: built_in_function::BuiltinScalarFunction,
pub func_def: ScalarFunctionDefinition,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
}

impl ScalarFunctionDefinition {
/// Function's name for display
pub fn name(&self) -> &str {
match self {
ScalarFunctionDefinition::BuiltIn { name, .. } => name.as_ref(),
ScalarFunctionDefinition::UDF(udf) => udf.name(),
ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(),
}
}
}

impl ScalarFunction {
/// Create a new ScalarFunction expression
pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec<Expr>) -> Self {
Self { fun, args }
Self {
func_def: ScalarFunctionDefinition::BuiltIn {
fun,
name: Arc::from(fun.to_string()),
},
args,
}
}
}

/// ScalarUDF expression invokes a user-defined scalar function [`ScalarUDF`]
///
/// [`ScalarUDF`]: crate::ScalarUDF
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarUDF {
/// The function
pub fun: Arc<crate::ScalarUDF>,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
}

impl ScalarUDF {
/// Create a new ScalarUDF expression
pub fn new(fun: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
Self { fun, args }
/// Create a new ScalarFunction expression with a user-defined function (UDF)
pub fn new_udf(udf: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
Self {
func_def: ScalarFunctionDefinition::UDF(udf),
args,
}
}
}

Expand Down Expand Up @@ -736,7 +758,6 @@ impl Expr {
Expr::Placeholder(_) => "Placeholder",
Expr::ScalarFunction(..) => "ScalarFunction",
Expr::ScalarSubquery { .. } => "ScalarSubquery",
Expr::ScalarUDF(..) => "ScalarUDF",
Expr::ScalarVariable(..) => "ScalarVariable",
Expr::Sort { .. } => "Sort",
Expr::TryCast { .. } => "TryCast",
Expand Down Expand Up @@ -1198,11 +1219,8 @@ impl fmt::Display for Expr {
write!(f, " NULLS LAST")
}
}
Expr::ScalarFunction(func) => {
fmt_function(f, &func.fun.to_string(), false, &func.args, true)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
fmt_function(f, fun.name(), false, args, true)
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
fmt_function(f, func_def.name(), false, args, true)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down Expand Up @@ -1534,11 +1552,8 @@ fn create_name(e: &Expr) -> Result<String> {
}
}
}
Expr::ScalarFunction(func) => {
create_function_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_name(fun.name(), false, args)
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
create_function_name(func_def.name(), false, args)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down
Loading

0 comments on commit f8dcc64

Please sign in to comment.