Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plugable handler for CREATE FUNCTION #9333

Merged
merged 25 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c87b609
Add plugable function factory
milenkovicm Feb 21, 2024
be5b233
cover `DROP FUNCTION` as well ...
milenkovicm Feb 24, 2024
76dc3bf
update documentation
milenkovicm Feb 28, 2024
f265d68
fix doc test
milenkovicm Feb 28, 2024
f084201
Address PR comments (code organization)
milenkovicm Feb 29, 2024
a8d54b6
Address PR comments (factory interface)
milenkovicm Feb 29, 2024
262f0f6
fix test after rebase
milenkovicm Feb 29, 2024
fe63c31
`remove`'s gone from the trait ...
milenkovicm Mar 2, 2024
b975df7
Rename FunctionDefinition and export it ...
milenkovicm Mar 2, 2024
9d1d715
Update datafusion/expr/src/logical_plan/ddl.rs
milenkovicm Mar 2, 2024
7e7d896
Update datafusion/core/src/execution/context/mod.rs
milenkovicm Mar 2, 2024
0430c11
Update datafusion/core/tests/user_defined/user_defined_scalar_functio…
milenkovicm Mar 2, 2024
210b194
Update datafusion/expr/src/logical_plan/ddl.rs
milenkovicm Mar 2, 2024
1d5d739
resolve part of follow up comments
milenkovicm Mar 2, 2024
84b0fbd
Qualified functions are not supported anymore
milenkovicm Mar 2, 2024
b642570
update docs and todos
milenkovicm Mar 2, 2024
b8f8991
fix clippy
milenkovicm Mar 2, 2024
5a9ad09
address additional comments
milenkovicm Mar 2, 2024
58479e3
Add sqllogicteset for CREATE/DROP function
alamb Mar 3, 2024
83acc8c
Add coverage for DROP FUNCTION IF EXISTS
alamb Mar 3, 2024
383602c
fix multiline error
alamb Mar 3, 2024
00b8058
revert dialect back to generic in test ...
milenkovicm Mar 5, 2024
f27d800
Merge remote-tracking branch 'apache/main' into create_function_factory
alamb Mar 5, 2024
d7e37ed
Merge remote-tracking branch 'apache/main' into create_function_factory
alamb Mar 5, 2024
8a0f42f
fmt
alamb Mar 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ use crate::datasource::{
};
use crate::error::{DataFusionError, Result};
use crate::logical_expr::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable,
CreateView, DropCatalogSchema, DropTable, DropView, Explain, LogicalPlan,
LogicalPlanBuilder, SetVariable, TableSource, TableType, UNNAMED_TABLE,
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, DropView,
Explain, LogicalPlan, LogicalPlanBuilder, SetVariable, TableSource, TableType,
UNNAMED_TABLE,
};
use crate::optimizer::OptimizerRule;
use datafusion_sql::{
Expand Down Expand Up @@ -489,6 +490,8 @@ impl SessionContext {
DdlStatement::DropTable(cmd) => self.drop_table(cmd).await,
DdlStatement::DropView(cmd) => self.drop_view(cmd).await,
DdlStatement::DropCatalogSchema(cmd) => self.drop_schema(cmd).await,
DdlStatement::CreateFunction(cmd) => self.create_function(cmd).await,
DdlStatement::DropFunction(cmd) => self.drop_function(cmd).await,
},
// TODO what about the other statements (like TransactionStart and TransactionEnd)
LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
Expand Down Expand Up @@ -794,6 +797,55 @@ impl SessionContext {
Ok(false)
}

/// Handles a `CREATE FUNCTION` statement.
///
/// Asks the configured [`FunctionFactory`] to build the function described by
/// `stmt`, registers the result with this context and returns an empty
/// [`DataFrame`].
///
/// # Errors
///
/// Returns [`DataFusionError::Configuration`] when no function factory has
/// been configured. Errors from the factory and from registering a scalar,
/// aggregate or window function are propagated.
async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
milenkovicm marked this conversation as resolved.
Show resolved Hide resolved
let function = {
// Work on a clone of the session state; presumably this keeps the state
// lock from being held while the factory is awaited below.
let state = self.state.read().clone();
let function_factory = &state.function_factory;

match function_factory {
Some(f) => f.create(state.config(), stmt).await?,
// `Err(..)?` returns the error from this function immediately
_ => Err(DataFusionError::Configuration(
"Function factory has not been configured".into(),
))?,
}
};

// Register the created function according to its kind
match function {
RegisterFunction::Scalar(f) => {
self.state.write().register_udf(f)?;
}
RegisterFunction::Aggregate(f) => {
self.state.write().register_udaf(f)?;
}
RegisterFunction::Window(f) => {
self.state.write().register_udwf(f)?;
}
// table functions are registered under the name supplied by the factory
RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
};

self.return_empty_dataframe()
}

/// Handles a `DROP FUNCTION` statement.
///
/// Deregisters `stmt.name` as a scalar (udf), aggregate (udaf) and window
/// (udwf) function and returns an empty [`DataFrame`].
///
/// # Errors
///
/// Returns an execution error when nothing was deregistered and the
/// statement was not `DROP FUNCTION IF EXISTS`. Errors from the deregister
/// calls are propagated.
async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
// The function type is not known at this point, so the name is
// deregistered from every function kind handled below.
// NOTE(review): functions registered through `RegisterFunction::Table`
// are not deregistered here -- confirm whether that is intended.
let mut dropped = false;
dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();

// `DROP FUNCTION IF EXISTS` succeeds even when no function with that name
// was registered. Plain `DROP FUNCTION` returns an error in that case.

if !stmt.if_exists && !dropped {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to add the if exist test like DROP FUNCTION if exists abs; so the logic here is covered.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a good call -- I added a test to sqllogictest

exec_err!("Function does not exist")
} else {
self.return_empty_dataframe()
}
}

/// Registers a variable provider within this context.
pub fn register_variable(
&self,
Expand Down Expand Up @@ -1261,7 +1313,30 @@ impl QueryPlanner for DefaultQueryPlanner {
.await
}
}
/// A pluggable interface to handle `CREATE FUNCTION` statements.
///
/// An implementation builds the function described by the statement and
/// returns it as a [RegisterFunction], which tells the caller what kind of
/// function (udf, udaf, udwf or udtf) has to be registered.

#[async_trait]
pub trait FunctionFactory: Sync + Send {
milenkovicm marked this conversation as resolved.
Show resolved Hide resolved
/// Handles creation of the user defined function specified in a
/// [CreateFunction] statement.
///
/// Note that, despite its name, `state` is a [SessionConfig].
async fn create(
&self,
state: &SessionConfig,
statement: CreateFunction,
) -> Result<RegisterFunction>;
}

/// The function produced by [`FunctionFactory::create`], tagged with the kind
/// of user defined function it has to be registered as.
pub enum RegisterFunction {
/// Scalar user defined function
Scalar(Arc<ScalarUDF>),
/// Aggregate user defined function
Aggregate(Arc<AggregateUDF>),
/// Window user defined function
Window(Arc<WindowUDF>),
/// Table user defined function: the name to register it under and its
/// implementation
Table(String, Arc<dyn TableFunctionImpl>),
}
/// Execution context for registering data sources and executing queries.
/// See [`SessionContext`] for a higher level API.
///
Expand Down Expand Up @@ -1306,6 +1381,12 @@ pub struct SessionState {
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
/// Runtime environment
runtime_env: Arc<RuntimeEnv>,

/// [FunctionFactory] to support pluggable user defined function handler.
///
/// It will be invoked on `CREATE FUNCTION` statements.
/// thus, changing dialect to PostgreSql is required
function_factory: Option<Arc<dyn FunctionFactory>>,
}

impl Debug for SessionState {
Expand Down Expand Up @@ -1392,6 +1473,7 @@ impl SessionState {
execution_props: ExecutionProps::new(),
runtime_env: runtime,
table_factories,
function_factory: None,
};

// register built in functions
Expand Down Expand Up @@ -1568,6 +1650,15 @@ impl SessionState {
self
}

/// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements
pub fn with_function_factory(
mut self,
function_factory: Arc<dyn FunctionFactory>,
) -> Self {
self.function_factory = Some(function_factory);
self
}

/// Replace the extension [`SerializerRegistry`]
pub fn with_serializer_registry(
mut self,
Expand Down
130 changes: 128 additions & 2 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

use arrow::compute::kernels::numeric::add;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
Array, ArrayRef, ArrowNativeTypeOp, Float32Array, Float64Array, Int32Array,
RecordBatch, UInt8Array,
};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
Expand All @@ -31,10 +33,12 @@ use datafusion_common::{
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use parking_lot::Mutex;

use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
Expand Down Expand Up @@ -735,6 +739,128 @@ async fn verify_udf_return_type() -> Result<()> {
Ok(())
}

/// Test-only [`FunctionFactory`] used to exercise `CREATE FUNCTION` handling.
#[derive(Debug, Default)]
struct MockFunctionFactory {
/// Expression taken from `params.return_` of the most recent statement
/// passed to `create`, kept so tests can verify that it has been parsed.
pub captured_expr: Mutex<Option<Expr>>,
}

#[async_trait::async_trait]
impl FunctionFactory for MockFunctionFactory {
#[doc = r" Crates and registers a function from [CreateFunction] statement"]
#[must_use]
#[allow(clippy::type_complexity, clippy::type_repetition_in_bounds)]
async fn create(
&self,
_config: &SessionConfig,
statement: CreateFunction,
) -> datafusion::error::Result<RegisterFunction> {
// In this example, we always create a function that adds its arguments
// with the name specified in `CREATE FUNCTION`. In a real implementation
// the body of the created UDF would also likely be a function of the contents
// of the `CreateFunction`
let mock_add = Arc::new(|args: &[datafusion_expr::ColumnarValue]| {
let args = datafusion_expr::ColumnarValue::values_to_arrays(args)?;
let base =
datafusion_common::cast::as_float64_array(&args[0]).expect("cast failed");
let exponent =
datafusion_common::cast::as_float64_array(&args[1]).expect("cast failed");

let array = base
.iter()
.zip(exponent.iter())
.map(|(base, exponent)| match (base, exponent) {
(Some(base), Some(exponent)) => Some(base.add_wrapping(exponent)),
_ => None,
})
.collect::<arrow_array::Float64Array>();
Ok(datafusion_expr::ColumnarValue::from(
Arc::new(array) as arrow_array::ArrayRef
))
});

let args = statement.args.unwrap();
let mock_udf = create_udf(
&statement.name,
vec![args[0].data_type.clone(), args[1].data_type.clone()],
Arc::new(statement.return_type.unwrap()),
datafusion_expr::Volatility::Immutable,
mock_add,
);

// capture expression so we can verify
// it has been parsed
*self.captured_expr.lock() = statement.params.return_;

Ok(RegisterFunction::Scalar(Arc::new(mock_udf)))
}
}

// End-to-end check of `CREATE FUNCTION` / `DROP FUNCTION` handling through a
// `SessionContext` whose state is configured with `MockFunctionFactory`.
#[tokio::test]
async fn create_scalar_function_from_sql_statement() -> Result<()> {
let function_factory = Arc::new(MockFunctionFactory::default());
let runtime_config = RuntimeConfig::new();
let runtime_environment = RuntimeEnv::new(runtime_config)?;

let session_config = SessionConfig::new();
let state =
SessionState::new_with_config_rt(session_config, Arc::new(runtime_environment))
.with_function_factory(function_factory.clone());

let ctx = SessionContext::new_with_state(state);
// options with DDL disabled, used to check that function DDL is rejected
let options = SQLOptions::new().with_allow_ddl(false);

let sql = r#"
CREATE FUNCTION better_add(DOUBLE, DOUBLE)
RETURNS DOUBLE
RETURN $1 + $2
"#;

// try to `create function` when sql options have allow ddl disabled
assert!(ctx.sql_with_options(sql, options).await.is_err());

// Create the `better_add` function dynamically via CREATE FUNCTION statement
assert!(ctx.sql(sql).await.is_ok());
// try to `drop function` when sql options have allow ddl disabled
assert!(ctx
.sql_with_options("drop function better_add", options)
.await
.is_err());

// the function is still registered, so it can be called
ctx.sql("select better_add(2.0, 2.0)").await?.show().await?;

// check that the sql expression has been converted to a datafusion expr
let captured_expression = function_factory.captured_expr.lock().clone().unwrap();
assert_eq!("$1 + $2", captured_expression.to_string());

// statement drops function
assert!(ctx.sql("drop function better_add").await.is_ok());
// no function left to drop, so the statement returns an error
assert!(ctx.sql("drop function better_add").await.is_err());
// no function, but `if exists` makes the statement succeed
assert!(ctx.sql("drop function if exists better_add").await.is_ok());
// query should fail as there is no function
assert!(ctx.sql("select better_add(2.0, 2.0)").await.is_err());

milenkovicm marked this conversation as resolved.
Show resolved Hide resolved
// tests expression parsing:
// a `RETURN` expression that is not correct must be rejected
let bad_expression_sql = r#"
CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE)
RETURNS DOUBLE
RETURN $1 $3
"#;
assert!(ctx.sql(bad_expression_sql).await.is_err());

// tests bad function definition (`RET BAD_TYPE` instead of `RETURNS <type>`)
let bad_definition_sql = r#"
CREATE FUNCTION bad_definition_fun(DOUBLE, DOUBLE)
RET BAD_TYPE
RETURN $1 + $3
"#;
assert!(ctx.sql(bad_definition_sql).await.is_err());

Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
Loading
Loading